├── .gitignore ├── README.md ├── dataset_creation ├── .gitignore ├── NLRBScraper.ipynb ├── bar_outlines │ └── scrape_bar_exam_outlines.py ├── bva │ └── process_bva.py ├── canadian_decisions │ ├── grab_canadian_decisions.ipynb │ └── post_process.py ├── collect_atticus_contracts.py ├── congressional_hearings │ └── scrape_congressional_hearings.py ├── constitutions │ └── scrape_constitutions.py ├── courtlistener │ ├── scrape_and_process_cl_data.py │ └── scrape_and_process_cl_docket_data.py ├── creative_commons_casebooks │ └── scrape_cc_casebooks.py ├── creditcardcfpb │ └── scrape_cfbp_cc_agreements.py ├── dol_ecab │ └── scrape_dol_ecab.py ├── echr │ └── scrape_echr.py ├── edgar_contracts │ └── process_edgar_contracts_dataset.py ├── eoir │ └── scrape_eoir.py ├── euro_parl.py ├── eurolex │ └── scrape_eurolex.py ├── federal_register │ └── process_federal_register.py ├── federal_rules │ ├── frcp.py │ └── fre.py ├── founder_docs │ └── scrape_founding_docs.py ├── ftcadvisories │ └── scrape_ftc_advisories.py ├── oig_reports │ └── scrape_oig_reports.py ├── olc │ └── scrape_olc_memos.py ├── process_cfr.py ├── process_scotus_oral_arguments.py ├── r_legaladvice │ └── scrape_r_legaladvice.py ├── scotus_dockets │ └── scrape_scotus_dockets.py ├── scrape_nlrb_decisions.py ├── state_codes │ ├── process_existing_state_code_files.py │ └── state_codes_from_scratch.py ├── tax_corpus_jhu │ └── process_tax_corpus_jhu.py ├── tos │ └── process_tos.py ├── un_debates │ └── scrape_un_debates.py ├── us_bills │ ├── process_us_bills.py │ ├── run.sh │ └── split_us_bills.py └── uscode │ └── scrape_us_code.py ├── pretraining ├── README.md ├── chunkify_and_hd5.py ├── new_vocab.py └── pol-finetuning-main │ ├── README.md │ ├── environment.yml │ ├── scripts │ └── run_casehold.sh │ └── tasks │ ├── adam_bias_correct.py │ ├── casehold.py │ ├── casehold_bc.py │ └── casehold_helpers.py ├── privacy ├── README.md ├── eoir │ ├── EOIR.ipynb │ ├── EOIR_validation_exp.ipynb │ ├── README.md │ ├── causal_exp.py │ └── create_pseudonyms_dataset.py └── janedoe │ ├── README.md │ ├── jane_doe.py │ ├── jane_doe_negative.py │ └── jane_doe_plot.py ├── scrub ├── README.md ├── pii.py └── scrub.py └── toxicity ├── README.md ├── create_doc_sent_index.py ├── fig2 ├── code │ └── fig2.R └── data │ ├── SCDB.Rdata │ ├── SCDBLegacy.Rdata │ ├── SCDB_2018_02_caseCentered_Citation.csv │ ├── SCDB_2019_01_caseCentered_Citation.Rdata │ └── SCDB_crosswalk.csv ├── scotus_only_pc_only.py ├── scotus_only_perspective_pc.py ├── scotus_only_toxigen.py ├── scotus_only_unitary.py └── toxigen_context_exp.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | dataset_creation/*/cache/ 131 | dataset_creation/*/cache_bak 132 | .DS_Store 133 | .tmp 134 | scratch 135 | *sqlite 136 | *.jsonl 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PileOfLaw 2 | 3 | This is the codebase containing the experiments and data-scraping tools used to gather the Pile of Law. 4 | 5 | Note that the Pile of Law dataset itself can be found here: https://huggingface.co/datasets/pile-of-law/pile-of-law 6 | 7 | Pretrained models (though they may be undertrained) can be found here: 8 | 9 | https://huggingface.co/pile-of-law/legalbert-large-1.7M-1 10 | 11 | https://huggingface.co/pile-of-law/legalbert-large-1.7M-2 12 | 13 | Note that the main model we reference in the paper is *-2. 14 | 15 | We can make intermediate checkpoints (taken every 50k steps) available on request; we do not upload them all because they would require a significant amount of storage. 16 | 17 | The EOIR privacy pretrained model is available at https://huggingface.co/pile-of-law/distilbert-base-uncased-finetuned-eoir_privacy 18 | 19 | All model cards and data cards are included in the associated repositories. A minimal example of loading the dataset and models from the Hub is shown below. Code for the paper is organized as follows. 20 | 21 | ## Dataset Creation 22 | 23 | All of the tools used to scrape each subset of the data live in the dataset_creation subfolder. 24 | 25 | ## Privacy Experiments 26 | 27 | The privacy experiments live under the privacy folder, with a separate README explaining those experiments.
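## Loading the released data and models

For convenience, the dataset and models above can be pulled directly from the Hugging Face Hub. Below is a minimal sketch, assuming the `datasets` and `transformers` libraries are installed; the subset name used here is only an illustrative assumption, so consult the dataset card for the configurations that are actually available.

```python
# Minimal sketch: stream one Pile of Law subset and load the main pretrained model.
# "r_legaladvice" is an assumed, illustrative subset name; check the dataset card
# for the real list of configurations.
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Streaming avoids downloading the full ~256GB corpus up front.
pile = load_dataset("pile-of-law/pile-of-law", "r_legaladvice", split="train", streaming=True)
print(next(iter(pile))["text"][:500])

# The *-2 checkpoint is the main model referenced in the paper.
tokenizer = AutoTokenizer.from_pretrained("pile-of-law/legalbert-large-1.7M-2")
model = AutoModelForMaskedLM.from_pretrained("pile-of-law/legalbert-large-1.7M-2")
```

Drop `streaming=True` if you prefer to download and cache a subset locally.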
28 | 29 | ## Toxicity Experiments 30 | 31 | The toxicity experiments live under the toxicity subfolder 32 | 33 | ## Pretraining and fine-tuning 34 | 35 | The pretraining processing scripts and fine-tuning scripts live under the pretraining folder. 36 | 37 | ## Data scrubbing 38 | 39 | We also examine all datasets for SSNs. While we log all private info, we found that most APIs had a significant number of false positives. We narrowly looked for SSNs, but did not encounter any. It is possible the filters we use are not robust, but similarly all datasets we use should already be pre-filtered for such sensitive information also. Scripts for this can be found in the scrub folder. 40 | 41 | 42 | ## Citations 43 | 44 | Please cite our work if you use any of the tools here and/or the data. 45 | 46 | ``` 47 | @article{hendersonkrass2022pileoflaw, 48 | title={Pile of Law: Learning Responsible Data Filtering from the Law and a 256GB Open-Source Legal Dataset}, 49 | author={Henderson*, Peter and Krass*, Mark and Zheng, Lucia and Guha, Neel and Manning, Christopher and Jurafsky, Dan and Ho, Daniel E}, 50 | year={2022} 51 | } 52 | ``` 53 | 54 | Some of the datasets in this work are transformed from prior work. Please cite these works as well if you use this dataset: 55 | 56 | ``` 57 | @inproceedings{borchmann-etal-2020-contract, 58 | title = "Contract Discovery: Dataset and a Few-Shot Semantic Retrieval Challenge with Competitive Baselines", 59 | author = "Borchmann, {\L}ukasz and 60 | Wisniewski, Dawid and 61 | Gretkowski, Andrzej and 62 | Kosmala, Izabela and 63 | Jurkiewicz, Dawid and 64 | Sza{\l}kiewicz, {\L}ukasz and 65 | Pa{\l}ka, Gabriela and 66 | Kaczmarek, Karol and 67 | Kaliska, Agnieszka and 68 | Grali{\'n}ski, Filip", 69 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020", 70 | month = nov, 71 | year = "2020", 72 | address = "Online", 73 | publisher = "Association for Computational Linguistics", 74 | url = "https://www.aclweb.org/anthology/2020.findings-emnlp.380", 75 | pages = "4254--4268", 76 | abstract = "We propose a new shared task of semantic retrieval from legal texts, in which a so-called contract discovery is to be performed {--} where legal clauses are extracted from documents, given a few examples of similar clauses from other legal acts. The task differs substantially from conventional NLI and shared tasks on legal information extraction (e.g., one has to identify text span instead of a single document, page, or paragraph). The specification of the proposed task is followed by an evaluation of multiple solutions within the unified framework proposed for this branch of methods. It is shown that state-of-the-art pretrained encoders fail to provide satisfactory results on the task proposed. In contrast, Language Model-based solutions perform better, especially when unsupervised fine-tuning is applied. Besides the ablation studies, we addressed questions regarding detection accuracy for relevant text fragments depending on the number of examples available. 
In addition to the dataset and reference results, LMs specialized in the legal domain were made publicly available.", 77 | } 78 | 79 | @data{T1/N1X6I4_2020, 80 | author = {Blair-Stanek, Andrew and Holzenberger, Nils and Van Durme, Benjamin}, 81 | publisher = {Johns Hopkins University Data Archive}, 82 | title = "{Tax Law NLP Resources}", 83 | year = {2020}, 84 | version = {V2}, 85 | doi = {10.7281/T1/N1X6I4}, 86 | url = {https://doi.org/10.7281/T1/N1X6I4} 87 | } 88 | 89 | @article{hendrycks2021cuad, 90 | title={CUAD: An Expert-Annotated NLP Dataset for Legal Contract Review}, 91 | author={Dan Hendrycks and Collin Burns and Anya Chen and Spencer Ball}, 92 | journal={arXiv preprint arXiv:2103.06268}, 93 | year={2021} 94 | } 95 | 96 | @inproceedings{koehn2005europarl, 97 | title={Europarl: A parallel corpus for statistical machine translation}, 98 | author={Koehn, Philipp and others}, 99 | booktitle={MT summit}, 100 | volume={5}, 101 | pages={79--86}, 102 | year={2005}, 103 | organization={Citeseer} 104 | } 105 | 106 | @article{DBLP:journals/corr/abs-1805-01217, 107 | author = {Marco Lippi and 108 | Przemyslaw Palka and 109 | Giuseppe Contissa and 110 | Francesca Lagioia and 111 | Hans{-}Wolfgang Micklitz and 112 | Giovanni Sartor and 113 | Paolo Torroni}, 114 | title = {{CLAUDETTE:} an Automated Detector of Potentially Unfair Clauses in 115 | Online Terms of Service}, 116 | journal = {CoRR}, 117 | volume = {abs/1805.01217}, 118 | year = {2018}, 119 | url = {http://arxiv.org/abs/1805.01217}, 120 | archivePrefix = {arXiv}, 121 | eprint = {1805.01217}, 122 | timestamp = {Mon, 13 Aug 2018 16:49:16 +0200}, 123 | biburl = {https://dblp.org/rec/bib/journals/corr/abs-1805-01217}, 124 | bibsource = {dblp computer science bibliography, https://dblp.org} 125 | } 126 | 127 | @article{ruggeri2021detecting, 128 | title={Detecting and explaining unfairness in consumer contracts through memory networks}, 129 | author={Ruggeri, Federico and Lagioia, Francesca and Lippi, Marco and Torroni, Paolo}, 130 | journal={Artificial Intelligence and Law}, 131 | pages={1--34}, 132 | year={2021}, 133 | publisher={Springer} 134 | } 135 | 136 | @inproceedings{10.1145/3462757.3466066, 137 | author = {Huang, Zihan and Low, Charles and Teng, Mengqiu and Zhang, Hongyi and Ho, Daniel E. and Krass, Mark S. and Grabmair, Matthias}, 138 | title = {Context-Aware Legal Citation Recommendation Using Deep Learning}, 139 | year = {2021}, 140 | isbn = {9781450385268}, 141 | publisher = {Association for Computing Machinery}, 142 | address = {New York, NY, USA}, 143 | url = {https://doi.org/10.1145/3462757.3466066}, 144 | doi = {10.1145/3462757.3466066}, 145 | abstract = {Lawyers and judges spend a large amount of time researching the proper legal authority 146 | to cite while drafting decisions. In this paper, we develop a citation recommendation 147 | tool that can help improve efficiency in the process of opinion drafting. We train 148 | four types of machine learning models, including a citation-list based method (collaborative 149 | filtering) and three context-based methods (text similarity, BiLSTM and RoBERTa classifiers). 150 | Our experiments show that leveraging local textual context improves recommendation, 151 | and that deep neural models achieve decent performance. We show that non-deep text-based 152 | methods benefit from access to structured case metadata, but deep models only benefit 153 | from such access when predicting from context of insufficient length. 
We also find 154 | that, even after extensive training, RoBERTa does not outperform a recurrent neural 155 | model, despite its benefits of pretraining. Our behavior analysis of the RoBERTa model 156 | further shows that predictive performance is stable across time and citation classes.}, 157 | booktitle = {Proceedings of the Eighteenth International Conference on Artificial Intelligence and Law}, 158 | pages = {79–88}, 159 | numpages = {10}, 160 | keywords = {neural natural language processing, legal opinion drafting, citation recommendation, legal text, citation normalization}, 161 | location = {S\~{a}o Paulo, Brazil}, 162 | series = {ICAIL '21} 163 | } 164 | ``` 165 | -------------------------------------------------------------------------------- /dataset_creation/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Breakend/PileOfLaw/e6f7b11ba52e0fbbd392ad00bf3cc1632ae710b6/dataset_creation/.gitignore -------------------------------------------------------------------------------- /dataset_creation/bar_outlines/scrape_bar_exam_outlines.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import numpy as np 4 | import json 5 | from urllib.parse import urljoin 6 | from bs4 import BeautifulSoup 7 | import textract 8 | import random 9 | import datetime 10 | 11 | try: 12 | import lzma as xz 13 | except ImportError: 14 | import pylzma as xz 15 | 16 | # Hacky, but this is just for scraping, just uncomment the one you want. 17 | #url = "https://law.stanford.edu/office-of-student-affairs/bar-exam-information/" 18 | url = "https://adamshajnfeld.weebly.com/" 19 | 20 | 21 | headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.76 Safari/537.36'} # This is chrome, you can set whatever browser you like 22 | #If there is no such folder, the script will create one automatically 23 | folder_location = './cache/' 24 | if not os.path.exists(folder_location): 25 | os.mkdir(folder_location) 26 | 27 | def save_to_processed(train, val, source_name, out_path): 28 | if not os.path.exists(out_path): 29 | os.makedirs(out_path) 30 | tf = os.path.join(out_path, f"train.{source_name}.jsonl") 31 | with open(tf, mode='w', encoding='utf-8') as out_file: 32 | for line in train: 33 | out_file.write(json.dumps(line) + "\n") 34 | print(f"Written {len(train)} documents to {tf}") 35 | 36 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl") 37 | with open(vf, mode='w', encoding='utf-8') as out_file: 38 | for line in val: 39 | out_file.write(json.dumps(line) + "\n") 40 | print(f"Written {len(val)} documents to {vf}") 41 | # now compress with lib 42 | print("compressing files...") 43 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out: 44 | out.write(xz.compress(bytes(f.read()))) 45 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out: 46 | out.write(xz.compress(bytes(f.read()))) 47 | print("compressed") 48 | response = requests.get(url, headers=headers) 49 | soup = BeautifulSoup(response.text, "html.parser") 50 | docs = [] 51 | for link in soup.select("a[href$='.doc']"): 52 | #Name the pdf files using the last portion of each link which are unique in this case 53 | filename = os.path.join(folder_location,link['href'].split('/')[-1]) 54 | if not os.path.exists(filename): 55 | with open(filename, 'wb') as f: 56 | f.write(requests.get(urljoin(url,link['href']), headers=headers).content) 57 | text = 
textract.process(filename) 58 | 59 | docs.append({ 60 | "url" : link['href'], 61 | "text" : str(text), 62 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 63 | "created_timestamp" : "08-30-2006" if "stanford" in url else "02-01-2013" # though we don't have an exact date for when the docs were created, 64 | # the author graduated in 2006 and I assume took the bar then and 65 | # created these for the bar so I'm assuming 2006 66 | # if it's the adam website, he state that he took the bar in feb 2013 67 | }) 68 | for link in soup.select("a[href$='.docx']"): 69 | #Name the pdf files using the last portion of each link which are unique in this case 70 | filename = os.path.join(folder_location,link['href'].split('/')[-1]) 71 | if not os.path.exists(filename): 72 | with open(filename, 'wb') as f: 73 | f.write(requests.get(urljoin(url,link['href']), headers=headers).content) 74 | text = textract.process(filename) 75 | docs.append({ 76 | "url" : link['href'], 77 | "text" : str(text), 78 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 79 | "created_timestamp" : "08-30-2006" if "stanford" in url else "02-01-2013" # though we don't have an exact date for when the docs were created, 80 | # the author graduated in 2006 and I assume took the bar then and 81 | # created these for the bar so I'm assuming 2006 82 | # if it's the adam website, he state that he took the bar in feb 2013 83 | }) 84 | 85 | random.seed(0) # important for shuffling 86 | rand_idx = list(range(len(docs))) 87 | random.shuffle(rand_idx) 88 | 89 | train_idx = rand_idx[:int(len(rand_idx)*0.75)] 90 | val_idx = rand_idx[int(len(rand_idx)*0.75):] 91 | 92 | train_docs = np.array(docs)[train_idx] 93 | val_docs = np.array(docs)[val_idx] 94 | 95 | if "stanford" in url: 96 | save_to_processed(train_docs, val_docs, "stanfordbarexamoutlines", "./cache/") 97 | elif "adam" in url: 98 | save_to_processed(train_docs, val_docs, "shajnfeldbarexamoutlines", "./cache/") 99 | -------------------------------------------------------------------------------- /dataset_creation/bva/process_bva.py: -------------------------------------------------------------------------------- 1 | # Processes BVA opinions (2001-2018) 2 | # Raw txt files downloaded through private link shared by authors of this paper: https://arxiv.org/abs/2106.10776 3 | # In the future, the files will be available here: https://drive.google.com/drive/folders/12lAd8Os7VFeqbTKi4wcqJqODjHIn0-yQ?usp=sharing 4 | 5 | import os 6 | import re 7 | import glob 8 | import json 9 | import tarfile 10 | import datetime 11 | import random 12 | from tqdm import tqdm 13 | from dateutil import parser 14 | 15 | 16 | # SOURCE URL 17 | URL = "https://drive.google.com/drive/folders/12lAd8Os7VFeqbTKi4wcqJqODjHIn0-yQ?usp=sharing" 18 | 19 | # DATA DIRS 20 | IN_DIR = "../../data/bva/raw" # path to BVA opinion tar files by year 21 | OUT_DIR = "../../data/bva/processed" # path to write data out to 22 | 23 | # Regex for date match 24 | date_regex = r'(?:\s*Decision\s+Date:\s+)(\d{1,2}/\d{1,2}/\d{1,2})' 25 | 26 | def save_to_file(data, out_dir, fname): 27 | if not os.path.exists(out_dir): 28 | os.makedirs(out_dir) 29 | fpath = os.path.join(out_dir, fname) 30 | with open(fpath, 'w') as out_file: 31 | for x in data: 32 | out_file.write(json.dumps(x) + "\n") 33 | print(f"Written {len(data)} to {fpath}") 34 | 35 | def main(): 36 | tar_files = glob.glob(os.path.join(IN_DIR, "*.tar.gz")) 37 | 38 | docs = [] 39 | for tar_file in tqdm(tar_files): 40 | print("Processing tar file:", 
tar_file) 41 | tar = tarfile.open(tar_file, 'r:gz') 42 | for member in tar.getmembers(): 43 | if ".txt" in member.name: 44 | f = tar.extractfile(member) 45 | content = f.read() 46 | # Original encoding is in latin1, we want to decode it as utf-8 47 | text = content.decode('latin1').encode('utf-8').decode('utf-8') 48 | 49 | # Extract creation date 50 | lines = text.splitlines() 51 | creation_date = "" 52 | match = None 53 | for line in lines: 54 | match = re.search(date_regex, line) 55 | # If matched date, break at current line and parse / reformat match 56 | if match: 57 | creation_date = parser.parse(match.group(1)).strftime("%m-%d-%Y") 58 | 59 | doc = { 60 | "url": URL, 61 | "created_timestamp": creation_date, 62 | "downloaded_timestamp": datetime.date.today().strftime("%m-%d-%Y"), 63 | "text": text 64 | } 65 | docs.append(doc) 66 | 67 | # Shuffle and split into train / validation 68 | random.seed(0) 69 | random.shuffle(docs) 70 | train = docs[:int(len(docs)*0.75)] 71 | validation = docs[int(len(docs)*0.75):] 72 | 73 | save_to_file(train, OUT_DIR, "train.bva.jsonl") 74 | save_to_file(validation, OUT_DIR, "validation.bva.jsonl") 75 | 76 | if __name__ == '__main__': 77 | main() -------------------------------------------------------------------------------- /dataset_creation/canadian_decisions/post_process.py: -------------------------------------------------------------------------------- 1 | 2 | from bs4 import BeautifulSoup 3 | import requests_cache 4 | import os 5 | import time 6 | import os 7 | from tqdm import tqdm 8 | import numpy as np 9 | import json 10 | import random 11 | import re 12 | import pickle 13 | import datetime 14 | try: 15 | import lzma as xz 16 | except ImportError: 17 | import pylzma as xz 18 | 19 | overwrite = True 20 | open_type = 'w' if overwrite else 'a' 21 | train_f = xz.open("./train.canadian_decisions.xz", open_type) 22 | val_f = xz.open("./validation.canadian_decisions.xz", open_type) 23 | 24 | with open("canada_cases.pickle", "rb") as f: 25 | _pickled = pickle.load(f) 26 | 27 | for key, value in _pickled.items(): 28 | if "The specific page has either moved or is no longer part" in value['text']: 29 | continue 30 | 31 | 32 | datapoint = { 33 | "text" : value['text'], 34 | "created_timestamp" : value['year'], 35 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 36 | "url" : "https://www.bccourts.ca/search_judgments.aspx?obd={}&court=1#SearchTitle" if value["jdx"] == "bc" else "https://www.ontariocourts.ca/coa/decisions_main" 37 | } 38 | 39 | if random.random() > .75: 40 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 41 | else: 42 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 43 | 44 | train_f.close() 45 | val_f.close() 46 | -------------------------------------------------------------------------------- /dataset_creation/collect_atticus_contracts.py: -------------------------------------------------------------------------------- 1 | # Code for formatting unlabelled Atticus contract data from: https://github.com/TheAtticusProject/cuad 2 | 3 | import glob 4 | import os 5 | import random 6 | from tqdm import tqdm 7 | import datetime 8 | import json 9 | 10 | IN_DIR = "../../data/atticus_contract/contracts/" 11 | OUT_DIR = "../../data/atticus_contract/processed" 12 | 13 | 14 | def save_to_file(data, fpath): 15 | with open(fpath, "w") as out_file: 16 | for x in data: 17 | out_file.write(json.dumps(x) + "\n") 18 | print(f"Written {len(data)} to {fpath}") 19 | 20 | def main(): 21 | 22 | # load contracts 23 | 
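# NOTE: the glob below assumes the unlabeled CUAD contract .txt files have been
# unpacked into per-part subdirectories one level below IN_DIR.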
files = glob.glob(os.path.join(IN_DIR, "*", "*.txt")) 24 | print(f"Collected {len(files)} contracts") 25 | docs = [] 26 | for f in files: 27 | text = "" 28 | with open(f) as in_file: 29 | for line in in_file: 30 | text = text + line 31 | 32 | doc = { 33 | "url": "https://github.com/TheAtticusProject/cuad", 34 | "created_timestamp": "", 35 | "timestamp": datetime.date.today().strftime("%m-%d-%Y"), 36 | "text": text 37 | } 38 | docs.append(doc) 39 | 40 | # shuffle and split into train / validation 41 | random.seed(0) 42 | random.shuffle(docs) 43 | train = docs[:int(len(docs)*0.75)] 44 | validation = docs[int(len(docs)*0.75):] 45 | 46 | save_to_file(train, os.path.join(OUT_DIR, "train.atticus_contracts.jsonl")) 47 | save_to_file(validation, os.path.join(OUT_DIR, "validation.atticus_contracts.jsonl")) 48 | 49 | 50 | 51 | if __name__ == "__main__": 52 | main() -------------------------------------------------------------------------------- /dataset_creation/congressional_hearings/scrape_congressional_hearings.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import scrapelib 4 | import pytz 5 | import os 6 | import cachetools 7 | import textract 8 | import numpy as np 9 | import random 10 | import json 11 | import requests_cache 12 | try: 13 | import lzma as xz 14 | except ImportError: 15 | import pylzma as xz 16 | 17 | requests = requests_cache.CachedSession('casebriefscache') 18 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'} 19 | headers['X-Api-Key'] = os.environ['API_KEY'] 20 | 21 | class GovInfo(scrapelib.Scraper): 22 | BASE_URL = 'https://api.govinfo.gov' 23 | 24 | def __init__(self, *args, **kwargs): 25 | super().__init__(*args, **kwargs) 26 | print(os.environ['API_KEY']) 27 | self.headers['X-Api-Key'] = os.environ['API_KEY'] 28 | 29 | 30 | def collections(self): 31 | endpoint = '/collections' 32 | response = requests.get(self.BASE_URL + endpoint, headers=headers) 33 | return response.json() 34 | 35 | 36 | def _format_time(self, dt): 37 | 38 | utc_time = dt.astimezone(pytz.utc) 39 | time_str = dt.strftime('%Y-%m-%dT%H:%M:%SZ') 40 | 41 | return time_str 42 | 43 | def congressional_hearings(self, congress=117): 44 | 45 | 46 | partial = f"/collections/CHRG/{self._format_time(datetime.datetime(1970, 1, 1, 0, 0, tzinfo=pytz.utc))}/" 47 | url_template = self.BASE_URL + partial + "{end_time}" 48 | end_time = datetime.datetime.now(pytz.utc) 49 | end_time_str = self._format_time(end_time) 50 | 51 | seen = cachetools.LRUCache(30) 52 | for page in self._pages(url_template, congress, end_time_str): 53 | for package in page['packages']: 54 | package_id = package['packageId'] 55 | 56 | if package_id in seen: 57 | continue 58 | else: 59 | # the LRUCache is like a dict, but all we care 60 | # about is whether we've seen this package 61 | # recently, so we just store None as the value 62 | # associated with the package_id key 63 | seen[package_id] = None 64 | 65 | response = requests.get(package['packageLink'], headers=headers) 66 | try: 67 | data = response.json() 68 | except: 69 | data = None 70 | 71 | yield data 72 | 73 | def _download_pdf(self, data): 74 | # if not "html" in data["download"]: 75 | # print(data["download"].keys()) 76 | # return 77 | if data is None or data["download"] is None: 78 | return None 79 | url = data["download"]["zipLink"] 80 | tag = url.split("/")[-2] 81 | url = 
f"https://www.govinfo.gov/content/pkg/{tag}/html/{tag}.htm" 82 | response = requests.get(url, headers=headers) 83 | 84 | if response.status_code != 200: 85 | return None 86 | 87 | text = str(response.content) 88 | 89 | # with open(f'cache/{tag}.pdf', 'wb') as f: 90 | # f.write(response.content) 91 | 92 | # text = str(textract.process(f'cache/{tag}.pdf', method="tesseract")) 93 | 94 | datapoint = { 95 | "text" : text, 96 | "created_timestamp" : data['dateIssued'], 97 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 98 | "url" : url 99 | } 100 | return datapoint 101 | 102 | def _pages(self, url_template, congress, end_time): 103 | page_size = 100 104 | 105 | params = {'offset': 0, 106 | 'pageSize': page_size, 107 | "congress" : congress} 108 | 109 | #url = url_template 110 | url = url_template.format(end_time=end_time) 111 | try_count = 0 112 | responded = False 113 | while not responded: 114 | response = requests.get(url, params=params, headers=headers) 115 | if response.status_code != 200: 116 | responded = False 117 | time.sleep(random.random() * 5) 118 | try_count += 1 119 | if try_count >= 3: 120 | responded = True 121 | else: 122 | responded = True 123 | data = response.json() 124 | 125 | yield data 126 | 127 | while len(data['packages']) == page_size: 128 | 129 | # the API results are sorted in descending order by timestamp 130 | # so we can paginate through results by making the end_time 131 | # filter earlier and earlier 132 | earliest_timestamp = data['packages'][-1]['lastModified'] 133 | url = url_template.format(end_time=earliest_timestamp) 134 | 135 | response = requests.get(url, params=params, headers=headers) 136 | data = response.json() 137 | 138 | yield data 139 | 140 | 141 | scraper = GovInfo() 142 | 143 | # Prints out all the different types of collections available 144 | # in the govinfo API 145 | print(scraper.collections()) 146 | 147 | # Iterate through every congressional hearing 148 | # 149 | # For congressional hearings you need a specify a start 150 | # date time with a timezone 151 | # start_time = datetime.datetime(2020, 1, 1, 0, 0, tzinfo=pytz.utc) 152 | congresses = np.arange(89, 118)[::-1] 153 | 154 | #seen_urls = [] 155 | #if os.path.exists("./cache/train.congressional_hearings.xz"): 156 | # with xz.open("./cache/train.congressional_hearings.xz", 'r') as f: 157 | # import pdb; pdb.set_trace() 158 | # seen_urls.extend([x["url"] for x in f.readlines()]) 159 | # with xz.open("./cache/validation.congressional_hearings.xz", 'r') as f: 160 | # seen_urls.extend([x["url"] for x in f.readlines()]) 161 | 162 | # val_f = xz.open("./cache/validation.congressional_hearings.xz", 'r') 163 | collected_urls = [] 164 | i = 0 165 | overwrite = True 166 | open_type = 'w' if overwrite else 'a' 167 | train_f = xz.open("./cache/train.congressional_hearings.xz", open_type) 168 | val_f = xz.open("./cache/validation.congressional_hearings.xz", open_type) 169 | for congress in congresses: 170 | print(f"NOW GETTING CONGRESS {congress}") 171 | for hearing in scraper.congressional_hearings(congress): 172 | datapoint = scraper._download_pdf(hearing) 173 | if datapoint is None: 174 | print("No data for hearing") 175 | print(hearing) 176 | continue 177 | #print(datapoint["url"]) 178 | if datapoint["url"] in collected_urls: 179 | print("ALREADY SAW URL!!") 180 | collected_urls.append(datapoint["url"]) 181 | i += 1 182 | if i % 1000 == 0: 183 | print(i) 184 | # import pdb; pdb.set_trace() 185 | if random.random() > .75: 186 | val_f.write((json.dumps(datapoint) + 
"\n").encode("utf-8")) 187 | else: 188 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 189 | -------------------------------------------------------------------------------- /dataset_creation/constitutions/scrape_constitutions.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | url= "https://www.constituteproject.org/service/constitutions?in_force=true" 4 | url_template = "https://www.constituteproject.org/constitution/{const_id}?lang=en" 5 | 6 | from bs4 import BeautifulSoup 7 | from bs4 import BeautifulSoup 8 | import requests_cache 9 | import os 10 | import time 11 | import os 12 | import textract 13 | import numpy as np 14 | import json 15 | import random 16 | import datetime 17 | requests = requests_cache.CachedSession('scotus') 18 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'} 19 | try: 20 | import lzma as xz 21 | except ImportError: 22 | import pylzma as xz 23 | 24 | cache = "./cache" 25 | 26 | if not os.path.exists(cache): 27 | os.mkdir(cache) 28 | 29 | 30 | response = requests.get(url, headers=headers) 31 | ids = [r["id"] for r in response.json()] 32 | overwrite = True 33 | open_type = 'w' if overwrite else 'a' 34 | train_f = xz.open("./cache/train.constitutions.xz", open_type) 35 | val_f = xz.open("./cache/validation.constitutions.xz", open_type) 36 | 37 | for _id in ids: 38 | print(_id) 39 | url = url_template.format(const_id= _id) 40 | response = requests.get(url, headers=headers) 41 | if not response.from_cache: 42 | time.sleep(random.random()*2) 43 | soup = BeautifulSoup(response.content) 44 | constitution_text_sections = soup.find_all("div", {"class" : "constitution-content__copy"}) 45 | title = soup.find("h1", {"class" : "clearfix"}).get_text() 46 | text = title 47 | for section in constitution_text_sections: 48 | text += section.get_text() 49 | text += "\n" 50 | 51 | datapoint = { 52 | "text" : text, 53 | "created_timestamp" : "", 54 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 55 | "url" : url 56 | } 57 | 58 | # This dataset is already heavily US biased, so it would be really weird not to train on the US 59 | # constitution 60 | if random.random() > .75 and "United_States" not in _id: 61 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 62 | else: 63 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 64 | 65 | -------------------------------------------------------------------------------- /dataset_creation/courtlistener/scrape_and_process_cl_data.py: -------------------------------------------------------------------------------- 1 | import wget 2 | import json 3 | import os 4 | import tarfile 5 | from tqdm import tqdm 6 | from pathlib import Path 7 | import io 8 | import json 9 | import tempfile 10 | import datetime 11 | import bs4 12 | import random 13 | import shutil 14 | 15 | 16 | try: 17 | import lzma as xz 18 | except ImportError: 19 | import pylzma as xz 20 | url = "https://www.courtlistener.com/api/bulk-data/opinions/all.tar" 21 | 22 | if not os.path.exists("./cache/all.tar"): 23 | import wget 24 | filename = wget.download(url, out="./cache/") 25 | 26 | 27 | 28 | def idempotent(x): 29 | return x 30 | 31 | 32 | def html2text(x): 33 | soup = bs4.BeautifulSoup(x, "lxml") 34 | return soup.get_text() 35 | 36 | field_order = [ 37 | ("plain_text", idempotent), 38 | ("html", html2text), 39 | ("html_lawbox", html2text), 40 | ("html_columbia", html2text), 41 | 
("html_with_citations", html2text), 42 | ("xml_harvard", html2text) 43 | ] 44 | 45 | error_str = ( 46 | "Unable to extract the content from this file. Please try reading the original." 47 | ) 48 | 49 | 50 | 51 | def parse_json(item): 52 | """ From https://github.com/thoppe/The-Pile-FreeLaw/blob/master/P1_extract_text.py 53 | """ 54 | 55 | js = json.loads(item) 56 | 57 | text = None 58 | 59 | if "html" in js and js["html"] == error_str: 60 | return None 61 | 62 | for k, func in field_order: 63 | if k in js and isinstance(js[k], str) and len(js[k]): 64 | text = func(js[k]) 65 | 66 | if text is None: 67 | print(f"Skipping {item}, couldn't find text.") 68 | return None 69 | 70 | return { 71 | "url" : js['resource_uri'], 72 | "text" : text, 73 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 74 | "created_timestamp" : datetime.datetime.strptime(js['date_created'].split("T")[0], '%Y-%m-%d').strftime("%m-%d-%Y") 75 | } 76 | 77 | train = 0 78 | val = 0 79 | 80 | with xz.open("./cache/train.courtlisteneropinions.xz", 'w') as train_f: 81 | with xz.open("./cache/validation.courtlisteneropinions.xz", 'w') as val_f: 82 | with tarfile.open("./cache/all.tar") as all_tar: 83 | for jxd in all_tar.getmembers(): 84 | with tarfile.open(fileobj=all_tar.extractfile(jxd)) as jxd_tar: 85 | for opinion in jxd_tar.getmembers(): 86 | datapoint = parse_json(jxd_tar.extractfile(opinion).read()) 87 | if random.random() > .75: 88 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 89 | val += 1 90 | else: 91 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 92 | train += 1 93 | 94 | total, used, free = shutil.disk_usage("/") 95 | 96 | if train % 10000 == 0: 97 | print(f"Have {train} documents and {val} validation documents!") 98 | print("Total: %d GiB" % (total // (2**30))) 99 | print("Used: %d GiB" % (used // (2**30))) 100 | print("Free: %d GiB" % (free // (2**30))) 101 | 102 | if (free // (2**30)) < 30: 103 | print("RUNNING OUT OF DISK!!!") 104 | if (free // (2**30)) < 25: 105 | print("Too little disk!!!") 106 | break 107 | 108 | print(f"Have {train} documents and {val} validation documents!") 109 | 110 | 111 | -------------------------------------------------------------------------------- /dataset_creation/courtlistener/scrape_and_process_cl_docket_data.py: -------------------------------------------------------------------------------- 1 | import wget 2 | import json 3 | import os 4 | import tarfile 5 | from tqdm import tqdm 6 | from pathlib import Path 7 | import io 8 | import json 9 | import tempfile 10 | import datetime 11 | import bs4 12 | import random 13 | import shutil 14 | 15 | 16 | try: 17 | import lzma as xz 18 | except ImportError: 19 | import pylzma as xz 20 | # url = "https://www.courtlistener.com/api/bulk-data/opinions/all.tar" 21 | 22 | # if not os.path.exists("./cache/all.tar"): 23 | # import wget 24 | # filename = wget.download(url, out="./cache/") 25 | 26 | 27 | 28 | def idempotent(x): 29 | return x 30 | 31 | 32 | def html2text(x): 33 | soup = bs4.BeautifulSoup(x, "lxml") 34 | return soup.get_text() 35 | 36 | field_order = [ 37 | ("plain_text", idempotent), 38 | ("html", html2text), 39 | ("html_lawbox", html2text), 40 | ("html_columbia", html2text), 41 | ("html_with_citations", html2text), 42 | ("xml_harvard", html2text) 43 | ] 44 | 45 | error_str = ( 46 | "Unable to extract the content from this file. Please try reading the original." 
47 | ) 48 | 49 | import requests 50 | import time 51 | import random 52 | import datetime 53 | import json 54 | import os 55 | 56 | def requestJSON(url): 57 | while True: 58 | try: 59 | r = requests.get(url, headers={'Authorization': f'Token {os.environ["API_KEY"]}' }) 60 | if r.status_code != 200: 61 | print('error code', r.status_code) 62 | time.sleep(5) 63 | continue 64 | else: 65 | break 66 | except Exception as e: 67 | print(e) 68 | time.sleep(5) 69 | continue 70 | return r.json() 71 | 72 | 73 | next_page = "https://www.courtlistener.com/api/rest/v3/recap-documents/?is_available=true" 74 | 75 | val=0 76 | train=0 77 | from dateutil.relativedelta import * 78 | import datetime 79 | import datefinder 80 | 81 | if os.path.exists("./cache/cur_url.txt"): 82 | with open("./cache/cur_url.txt", "r") as f: 83 | next_page = f.read().strip() 84 | dates = list(datefinder.find_dates(next_page)) 85 | cur_month = dates[0] 86 | prev_month = dates[1] 87 | if cur_month < prev_month: 88 | tmp = prev_month 89 | prev_month = cur_month 90 | cur_month = tmp 91 | else: 92 | cur_month = datetime.datetime.now() 93 | prev_month = cur_month - relativedelta(days=3) 94 | next_page = f"https://www.courtlistener.com/api/rest/v3/docket-entries/?date_filed__lt={cur_month.strftime('%Y-%m-%d')}&date_filed__gt={prev_month.strftime('%Y-%m-%d')}&fields=date_filed%2Crecap_documents%2Cdescription&recap_documents__is_available=true" 95 | 96 | 97 | with xz.open("./cache/train.courtlistenerdocketentries.xz", 'a') as train_f: 98 | with xz.open("./cache/validation.courtlistenerdocketentries.xz", 'a') as val_f: 99 | while True: 100 | #print(cur_month.strftime('%Y-%m-%d')) 101 | if next_page is None: 102 | next_page = f"https://www.courtlistener.com/api/rest/v3/docket-entries/?date_filed__lt={cur_month.strftime('%Y-%m-%d')}&date_filed__gt={prev_month.strftime('%Y-%m-%d')}&fields=date_filed%2Crecap_documents%2Cdescription&recap_documents__is_available=true" 103 | while next_page is not None: 104 | print(next_page) 105 | js_data = requestJSON(next_page) 106 | if 'count' in js_data: 107 | print(js_data['count']) 108 | time.sleep(random.random()*3) 109 | next_page = js_data["next"] 110 | if next_page is not None: 111 | with open('./cache/cur_url.txt', 'w') as f: 112 | f.write(next_page) 113 | for docket_entry in js_data["results"]: 114 | for recap_data in docket_entry["recap_documents"]: 115 | if "plain_text" in recap_data and recap_data["plain_text"]: 116 | datapoint = { 117 | "url" : recap_data['resource_uri'], 118 | "text" : recap_data["plain_text"], 119 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 120 | "created_timestamp" : docket_entry['date_filed'] 121 | } 122 | if random.random() > .75: 123 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 124 | val += 1 125 | else: 126 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 127 | train += 1 128 | if train % 5000 == 0: 129 | print(f"Have {train} documents and {val} validation documents!") 130 | cur_month = prev_month 131 | prev_month = cur_month - relativedelta(days=3) 132 | 133 | 134 | -------------------------------------------------------------------------------- /dataset_creation/creative_commons_casebooks/scrape_cc_casebooks.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import os 3 | import datetime 4 | import textract 5 | import random 6 | import json 7 | try: 8 | import lzma as xz 9 | except ImportError: 10 | import pylzma as xz 11 | overwrite = True 12 | open_type 
= 'w' if overwrite else 'a' 13 | train_f = xz.open("./cache/train.cc_casebooks.xz", open_type) 14 | val_f = xz.open("./cache/validation.cc_casebooks.xz", open_type) 15 | # ASSUMES THAT YOUVE DOWNLOADED THE TEXTBOOKS INTO THE CACHE AND THAT THEY'RE ALL CC LICENSE 16 | docs = [] 17 | for path, subdirs, files in os.walk("./cache/"): 18 | for name in files: 19 | if not name.endswith(".pdf"): 20 | continue 21 | text = str(textract.process(os.path.join(path, name))) 22 | datapoint = { 23 | "text" : text, 24 | "created_timestamp" : "", 25 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 26 | "url" : "https://open.umn.edu/opentextbooks/" 27 | } 28 | 29 | if random.random() > .75: 30 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 31 | else: 32 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) -------------------------------------------------------------------------------- /dataset_creation/creditcardcfpb/scrape_cfbp_cc_agreements.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import os 3 | import datetime 4 | import textract 5 | import random 6 | import json 7 | try: 8 | import lzma as xz 9 | except ImportError: 10 | import pylzma as xz 11 | url = "https://files.consumerfinance.gov/a/assets/Credit_Card_Agreements_2021_Q2.zip" 12 | 13 | cache_path = "./cache/Credit_Card_Agreements_2021_Q2.zip" 14 | if not os.path.exists(cache_path): 15 | import wget 16 | filename = wget.download(url, out="./cache/") 17 | 18 | with zipfile.ZipFile(cache_path, 'r') as zip_ref: 19 | zip_ref.extractall("./cache/") 20 | 21 | overwrite = True 22 | open_type = 'w' if overwrite else 'a' 23 | train_f = xz.open("./cache/train.cfpb_cc.xz", open_type) 24 | val_f = xz.open("./cache/validation.cfpb_cc.xz", open_type) 25 | 26 | docs = [] 27 | for path, subdirs, files in os.walk("./cache/"): 28 | for name in files: 29 | if not name.endswith(".pdf"): 30 | continue 31 | text = str(textract.process(os.path.join(path, name))) 32 | datapoint = { 33 | "text" : text, 34 | "created_timestamp" : "2021", 35 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 36 | "url" : url 37 | } 38 | 39 | if random.random() > .75: 40 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 41 | else: 42 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) -------------------------------------------------------------------------------- /dataset_creation/dol_ecab/scrape_dol_ecab.py: -------------------------------------------------------------------------------- 1 | from bs4 import BeautifulSoup 2 | import requests_cache 3 | import os 4 | import time 5 | import os 6 | import textract 7 | import numpy as np 8 | import json 9 | import random 10 | import datetime 11 | import re 12 | import pandas as pd 13 | 14 | try: 15 | import lzma as xz 16 | except ImportError: 17 | import pylzma as xz 18 | 19 | cache = "./cache" 20 | 21 | requests = requests_cache.CachedSession('scotus') 22 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'} 23 | 24 | if not os.path.exists(cache): 25 | os.mkdir(cache) 26 | 27 | docs = [] 28 | base_url = "https://www.dol.gov/agencies/ecab/decisions/{year}/{month}" 29 | pdf_url = "https://www.dol.gov/sites/dolgov/files/ecab/decisions/{year}/{month}/{tag}.pdf" 30 | 31 | docs_for_pseudonyms = [] 32 | 33 | num_pseudonyms = 0 34 | 35 | for vol in range(2007,2023): 36 | for month in ["Jan", "Feb", "Mar", 
"Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]: 37 | url = base_url.format(year=vol, month=month) 38 | # page2 = requests.get(url, headers=headers) 39 | # soup = BeautifulSoup(page2.text,"lxml") 40 | converters = {c: lambda x: str(x) for c in range(3)} 41 | try: 42 | table = pd.read_html(url, converters=converters) 43 | except: 44 | print(f"Skipping {month} {vol} b/c table formatting issue.") 45 | 46 | for row in table[0].iterrows(): 47 | tag, date, casename = row[1].values 48 | 49 | link = pdf_url.format(year=vol, month=month, tag=tag) 50 | 51 | print(tag) 52 | if not os.path.exists(f'{cache}/{tag}.pdf'): 53 | print(link) 54 | pdf = requests.get(link, headers=headers) 55 | 56 | with open(f'{cache}/{tag}.pdf', 'wb') as f: 57 | f.write(pdf.content) 58 | 59 | try: 60 | text = textract.process(f'{cache}/{tag}.pdf', encoding='utf-8') 61 | text = text.decode("utf8") 62 | except: 63 | print(f"Skipping {tag}!") 64 | continue 65 | 66 | try: 67 | datetime_object = datetime.datetime.strptime(date, '%B %d, %Y') 68 | except: 69 | try: 70 | datetime_object = datetime.datetime.strptime(date, '%B%d, %Y') 71 | except: 72 | date = " ".join(date.split(" ")[:-1]) 73 | try: 74 | datetime_object = datetime.datetime.strptime(date, '%B %d, %Y') 75 | except: 76 | continue 77 | timestamp = datetime_object.strftime("%m-%d-%Y") 78 | 79 | if len(text) < 100: 80 | import pdb; pdb.set_trace() 81 | 82 | docs.append({ 83 | "text" : text, 84 | "created_timestamp" : timestamp, 85 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 86 | "url" : link 87 | }) 88 | 89 | 90 | def save_to_processed(train, val, source_name, out_path): 91 | if not os.path.exists(out_path): 92 | os.makedirs(out_path) 93 | tf = os.path.join(out_path, f"train.{source_name}.jsonl") 94 | with open(tf, mode='w', encoding='utf-8') as out_file: 95 | for line in train: 96 | out_file.write(json.dumps(line) + "\n") 97 | print(f"Written {len(train)} documents to {tf}") 98 | 99 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl") 100 | with open(vf, mode='w', encoding='utf-8') as out_file: 101 | for line in val: 102 | out_file.write(json.dumps(line) + "\n") 103 | print(f"Written {len(val)} documents to {vf}") 104 | 105 | # now compress with lib 106 | print("compressing files...") 107 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out: 108 | out.write(xz.compress(bytes(f.read()))) 109 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out: 110 | out.write(xz.compress(bytes(f.read()))) 111 | print("compressed") 112 | 113 | random.seed(0) # important for shuffling 114 | rand_idx = list(range(len(docs))) 115 | random.shuffle(rand_idx) 116 | 117 | train_idx = rand_idx[:int(len(rand_idx)*0.75)] 118 | val_idx = rand_idx[int(len(rand_idx)*0.75):] 119 | 120 | train_docs = np.array(docs)[train_idx] 121 | val_docs = np.array(docs)[val_idx] 122 | 123 | save_to_processed(train_docs, val_docs, "dol_ecab", "./cache/") 124 | -------------------------------------------------------------------------------- /dataset_creation/echr/scrape_echr.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import os 3 | import datetime 4 | try: 5 | import lzma as xz 6 | except ImportError: 7 | import pylzma as xz 8 | url = "https://archive.org/download/ECHR-ACL2019/ECHR_Dataset.zip" 9 | 10 | cache_path = "./cache/ECHR_Dataset.zip" 11 | if not os.path.exists(cache_path): 12 | import wget 13 | filename = wget.download(url, out="./cache/") 14 | 15 | with zipfile.ZipFile(cache_path, 'r') 
as zip_ref: 16 | zip_ref.extractall("./cache/") 17 | 18 | 19 | overwrite = True 20 | open_type = 'w' if overwrite else 'a' 21 | train_f = xz.open("./cache/train.echr.xz", open_type) 22 | val_f = xz.open("./cache/validation.echr.xz", open_type) 23 | import os, json 24 | 25 | path_to_json = './cache/EN_train' 26 | json_files = [pos_json for pos_json in os.listdir(path_to_json) if pos_json.endswith('.json')] 27 | 28 | for json_file in json_files: 29 | with open('./cache/EN_train/' + json_file, "r") as f: 30 | loaded = json.loads(f.read()) 31 | 32 | blocklist = ["ITEMID"] 33 | text = "" 34 | for key, val in loaded.items(): 35 | if val != "" and val is not None: 36 | if not isinstance(val, list): 37 | text += f"{key}: {val}\n" 38 | else: 39 | if len(val) > 0: 40 | joined = '\n'.join(val) 41 | text += f"{key}: {joined}\n" 42 | datapoint = { 43 | "url" : url, 44 | "text" : text, 45 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 46 | "created_timestamp" : loaded["DATE"] 47 | } 48 | 49 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 50 | 51 | 52 | path_to_json = 'cache/EN_dev' 53 | json_files = [pos_json for pos_json in os.listdir(path_to_json) if pos_json.endswith('.json')] 54 | 55 | for json_file in json_files: 56 | with open( 'cache/EN_dev/'+ json_file, "r") as f: 57 | loaded = json.loads(f.read()) 58 | 59 | blocklist = ["ITEMID"] 60 | text = "" 61 | for key, val in loaded.items(): 62 | if val != "" and val is not None: 63 | if not isinstance(val, list): 64 | text += f"{key}: {val}\n" 65 | else: 66 | if len(val) > 0: 67 | joined= '\n'.join(val) 68 | text += f"{key}: {joined}\n" 69 | datapoint = { 70 | "url" : url, 71 | "text" : text, 72 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 73 | "created_timestamp" : loaded["DATE"] 74 | } 75 | 76 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 77 | 78 | -------------------------------------------------------------------------------- /dataset_creation/edgar_contracts/process_edgar_contracts_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import random 4 | import numpy as np 5 | import json 6 | 7 | 8 | try: 9 | import lzma as xz 10 | except ImportError: 11 | import pylzma as xz 12 | url = "https://applica-public.s3-eu-west-1.amazonaws.com/contract-discovery/edgar.txt.xz" 13 | 14 | if not os.path.exists("./cache/edgar.txt.xz"): 15 | import wget 16 | filename = wget.download(url, out="./cache/") 17 | 18 | def save_to_processed(train, val, source_name, out_path): 19 | if not os.path.exists(out_path): 20 | os.makedirs(out_path) 21 | tf = os.path.join(out_path, f"train.{source_name}.jsonl") 22 | with open(tf, mode='w', encoding='utf-8') as out_file: 23 | for line in train: 24 | out_file.write(json.dumps(line) + "\n") 25 | print(f"Written {len(train)} documents to {tf}") 26 | 27 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl") 28 | with open(vf, mode='w', encoding='utf-8') as out_file: 29 | for line in val: 30 | out_file.write(json.dumps(line) + "\n") 31 | print(f"Written {len(val)} documents to {vf}") 32 | # now compress with lib 33 | print("compressing files...") 34 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out: 35 | out.write(xz.compress(bytes(f.read()))) 36 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out: 37 | out.write(xz.compress(bytes(f.read()))) 38 | print("compressed") 39 | 40 | docs = [] 41 | 42 | with xz.open('./cache/edgar.txt.xz', mode='rt') as f: 43 | for 
line in f: 44 | docs.append({ 45 | "url" : url, 46 | "text" : line, 47 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 48 | "created_timestamp" : "" # Unfortunately the dataset that originally scraped the material didn't keep the date. 49 | }) 50 | 51 | random.seed(0) # important for shuffling 52 | rand_idx = list(range(len(docs))) 53 | random.shuffle(rand_idx) 54 | 55 | train_idx = rand_idx[:int(len(rand_idx)*0.75)] 56 | val_idx = rand_idx[int(len(rand_idx)*0.75):] 57 | 58 | train_docs = np.array(docs)[train_idx] 59 | val_docs = np.array(docs)[val_idx] 60 | 61 | save_to_processed(train_docs, val_docs, "edgar", "./cache/") 62 | -------------------------------------------------------------------------------- /dataset_creation/eoir/scrape_eoir.py: -------------------------------------------------------------------------------- 1 | from bs4 import BeautifulSoup 2 | import requests_cache 3 | import os 4 | import time 5 | import os 6 | import textract 7 | import numpy as np 8 | import json 9 | import random 10 | import datetime 11 | import re 12 | 13 | try: 14 | import lzma as xz 15 | except ImportError: 16 | import pylzma as xz 17 | 18 | cache = "./cache" 19 | 20 | requests = requests_cache.CachedSession('scotus') 21 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'} 22 | 23 | if not os.path.exists(cache): 24 | os.mkdir(cache) 25 | 26 | docs = [] 27 | base_url_27 = "https://www.justice.gov/eoir/volume-{vol:02}" 28 | base_url_26 = "https://www.justice.gov/eoir/precedent-decisions-volume-{vol:02}" 29 | 30 | docs_for_pseudonyms = [] 31 | 32 | num_pseudonyms = 0 33 | 34 | for vol in range(8,29): 35 | if vol >= 27: 36 | url = base_url_27.format(vol=vol) 37 | else: 38 | url = base_url_26.format(vol=vol) 39 | page2 = requests.get(url, headers=headers) 40 | soup = BeautifulSoup(page2.text,"lxml") 41 | g_dates = soup.findAll("td", attrs={'class': None, "colspan" : None}) 42 | g_data = soup.findAll("td", {"class" : "rteright"}) 43 | 44 | for data, date in zip(g_data, g_dates): 45 | # data = data.find("a") 46 | if "justice.gov" in data.find("a")["href"]: 47 | link = data.find("a")["href"] 48 | else: 49 | link = "https://www.justice.gov/" + data.find("a")["href"] 50 | if link.endswith(".pdf"): 51 | tag = link.split("/")[-1].replace(".pdf", "") 52 | else: 53 | tag = link.split("/")[-2] 54 | tag += "__" + str(vol) 55 | print(tag) 56 | if not os.path.exists(f'{cache}/{tag}.pdf'): 57 | print(link) 58 | pdf = requests.get(link) 59 | 60 | with open(f'{cache}/{tag}.pdf', 'wb') as f: 61 | f.write(pdf.content) 62 | 63 | text = textract.process(f'{cache}/{tag}.pdf', encoding='utf-8') 64 | text = text.decode("utf8") 65 | 66 | try: 67 | name = date.find("strong").text.replace(",", "") 68 | except: 69 | try: 70 | name = date.find("b").text.replace(",", "") 71 | except: 72 | import pdb; pdb.set_trace() 73 | 74 | print(name) 75 | 76 | def check_pseudo(ns): 77 | ns = ns.replace("et al.", "").strip() 78 | ns = ns.replace("‑", "-") 79 | if ns == "DEF-" or ns == "D'O-" or ns == "DEN-" or ns == "DEG" or ns == "DE M-" or ns == "D-S- INC." 
or ns == "DIP-": 80 | return True 81 | for n in re.split('(&|and|AND)',ns): 82 | n = n.strip() 83 | n = n.replace(" ", "") 84 | created_pseudo = "-".join(n.replace("-", "")) 85 | if created_pseudo == n or created_pseudo + "-" == n: 86 | return True 87 | created_pseudo = ".".join(n.replace(".", "")) 88 | if created_pseudo == n: 89 | return True 90 | return False 91 | 92 | is_pseudonym = check_pseudo(name) 93 | print(is_pseudonym) 94 | # if not is_pseudonym: 95 | # import pdb; pdb.set_trace() 96 | 97 | if is_pseudonym: 98 | num_pseudonyms +=1 99 | 100 | issuance_date = date.text.split("(")[-1].split(" ")[-1] 101 | issuance_date = issuance_date.replace(")", "").strip() 102 | print(issuance_date) 103 | if issuance_date == "": 104 | import pdb; pdb.set_trace() 105 | 106 | if len(text) < 100: 107 | import pdb; pdb.set_trace() 108 | 109 | docs.append({ 110 | "text" : text, 111 | "created_timestamp" : issuance_date, 112 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 113 | "url" : link 114 | }) 115 | 116 | docs_for_pseudonyms.append({ 117 | "text" : text, 118 | "created_timestamp" : issuance_date, 119 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 120 | "url" : link, 121 | "name" : name, 122 | "is_pseudonym" : is_pseudonym 123 | }) 124 | 125 | 126 | def save_to_processed(train, val, source_name, out_path): 127 | if not os.path.exists(out_path): 128 | os.makedirs(out_path) 129 | tf = os.path.join(out_path, f"train.{source_name}.jsonl") 130 | with open(tf, mode='w', encoding='utf-8') as out_file: 131 | for line in train: 132 | out_file.write(json.dumps(line) + "\n") 133 | print(f"Written {len(train)} documents to {tf}") 134 | 135 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl") 136 | with open(vf, mode='w', encoding='utf-8') as out_file: 137 | for line in val: 138 | out_file.write(json.dumps(line) + "\n") 139 | print(f"Written {len(val)} documents to {vf}") 140 | 141 | # now compress with lib 142 | print("compressing files...") 143 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out: 144 | out.write(xz.compress(bytes(f.read()))) 145 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out: 146 | out.write(xz.compress(bytes(f.read()))) 147 | print("compressed") 148 | 149 | random.seed(0) # important for shuffling 150 | rand_idx = list(range(len(docs))) 151 | random.shuffle(rand_idx) 152 | 153 | train_idx = rand_idx[:int(len(rand_idx)*0.75)] 154 | val_idx = rand_idx[int(len(rand_idx)*0.75):] 155 | 156 | train_docs = np.array(docs)[train_idx] 157 | val_docs = np.array(docs)[val_idx] 158 | 159 | save_to_processed(train_docs, val_docs, "eoir", "./cache/") 160 | 161 | save_to_processed(docs_for_pseudonyms, [], "eoir_pseudonym", "./cache/") 162 | 163 | print("NUM PSEUDOI") 164 | print(num_pseudonyms) 165 | -------------------------------------------------------------------------------- /dataset_creation/euro_parl.py: -------------------------------------------------------------------------------- 1 | # Processes data from the European Parliament Proceedings Parallel Corpus 1996-2011 from 2 | # https://www.statmt.org/europarl/. We only use the English data. We strip out all tags. 
3 | 4 | from bs4 import BeautifulSoup 5 | import os 6 | import glob 7 | import random 8 | from tqdm import tqdm 9 | import json 10 | import datetime 11 | 12 | URL = "https://www.statmt.org/europarl/" 13 | DATA_DIR = "../../data/europarl" 14 | OUT_DIR = "../../data/europarl/processed" 15 | 16 | def save_to_file(data, fpath): 17 | with open(fpath, "w") as out_file: 18 | for x in data: 19 | out_file.write(json.dumps(x) + "\n") 20 | print(f"Written {len(data)} to {fpath}") 21 | 22 | def process_file(f): 23 | doc = [] 24 | with open(f, "r") as in_file: 25 | for line in in_file: 26 | line = line.strip() 27 | if "<" in line or len(line) == 0: 28 | continue 29 | doc.append(line) 30 | return "\n".join(doc) 31 | 32 | 33 | def main(): 34 | 35 | # load data 36 | in_glob = os.path.join(DATA_DIR, "txt", "en", "*.txt") 37 | files = glob.glob(in_glob) 38 | print(f"Found {len(files)} files.") 39 | 40 | # parse files 41 | docs = [] 42 | for f in tqdm(files): 43 | text = process_file(f) 44 | doc = { 45 | "url": URL, 46 | "created_timestamp" : "", # We don't actually know individual document dates. 47 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 48 | "text": text 49 | } 50 | docs.append(doc) 51 | 52 | # shuffle and split into train / validation 53 | random.seed(0) 54 | random.shuffle(docs) 55 | train = docs[:int(len(docs)*0.75)] 56 | validation = docs[int(len(docs)*0.75):] 57 | 58 | save_to_file(train, os.path.join(OUT_DIR, "train.euro_parl.jsonl")) 59 | save_to_file(validation, os.path.join(OUT_DIR, "validation.euro_parl.jsonl")) 60 | 61 | 62 | 63 | 64 | if __name__ == "__main__": 65 | main() -------------------------------------------------------------------------------- /dataset_creation/eurolex/scrape_eurolex.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import random 3 | import json 4 | import datetime 5 | try: 6 | import lzma as xz 7 | except ImportError: 8 | import pylzma as xz 9 | 10 | lex = pd.read_csv("./cache/EurLex_all.csv") 11 | 12 | overwrite = True 13 | open_type = 'w' if overwrite else 'a' 14 | train_f = xz.open("./cache/train.eurlex.xz", open_type) 15 | val_f = xz.open("./cache/validation.eurlex.xz", open_type) 16 | 17 | for act_name, Act_type, subject_matter, date_publication, text in zip(lex["Act_name"], lex["Act_type"], lex["Subject_matter"], lex["Date_publication"], lex["act_raw_text"]): 18 | 19 | datapoint = { 20 | "text" : f"Name: {act_name}\n Type: {Act_type}\n Subject Matter: {subject_matter}\n Date Published: {date_publication}\n\n {text}", 21 | "created_timestamp" : date_publication, 22 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 23 | "url" : "https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/0EGYWY" 24 | } 25 | if random.random() > .75: 26 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 27 | else: 28 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) -------------------------------------------------------------------------------- /dataset_creation/federal_register/process_federal_register.py: -------------------------------------------------------------------------------- 1 | # Processes federal register proposed rules (2000 - 2021) 2 | # Bulk data pulled from: https://www.govinfo.gov/bulkdata/xml/FR 3 | 4 | import os 5 | import json 6 | import datetime 7 | import random 8 | from urllib.request import Request, urlopen 9 | # xpath only available in lxml etree, not ElementTree 10 | from lxml import etree 11 | from tqdm import tqdm 12 
| from dateutil import parser 13 | 14 | # BASE URL (bulk data API endpoint) 15 | BASE_URL = "https://www.govinfo.gov/bulkdata/xml/FR" 16 | 17 | # DATA DIRS 18 | OUT_DIR = "../../data/federal_register/processed" # path to write data out to 19 | 20 | # Request variables 21 | headers = {'User-Agent': 'Mozilla/5.0', 'Accept': 'application/xml'} 22 | 23 | 24 | def save_to_file(data, out_dir, fname): 25 | if not os.path.exists(out_dir): 26 | os.makedirs(out_dir) 27 | fpath = os.path.join(out_dir, fname) 28 | with open(fpath, 'w') as out_file: 29 | for x in data: 30 | out_file.write(json.dumps(x) + "\n") 31 | print(f"Written {len(data)} to {fpath}") 32 | 33 | 34 | def request_raw_data(): 35 | urls = [BASE_URL] 36 | xmls = [] 37 | while len(urls) > 0: 38 | next_urls = [] 39 | for url in urls: 40 | print(url) 41 | request = Request(url, headers=headers) 42 | with urlopen(request) as response: 43 | root = etree.fromstring(response.read()) 44 | elems = root.xpath("*/file[folder='true' and name!='resources']") 45 | if len(elems) > 0: 46 | for e in elems: 47 | next_url = e.find("link").text 48 | next_urls.append(next_url) 49 | else: 50 | elems = root.xpath("*/file[mimeType='application/xml']") 51 | for e in elems: 52 | xml_url = e.find("link").text 53 | request = Request(xml_url, headers=headers) 54 | with urlopen(request) as response: 55 | xml = etree.fromstring(response.read()) 56 | # Add tuple of xml_url, xml Element instance 57 | xmls.append((xml_url, xml)) 58 | urls = next_urls 59 | 60 | return xmls 61 | 62 | def extract_rule_docs(xmls): 63 | docs = [] 64 | for (xml_url, xml) in tqdm(xmls): 65 | print(xml_url) 66 | date = xml.find("DATE").text 67 | creation_date = "" 68 | try: 69 | creation_date = parser.parse(date).strftime("%m-%d-%Y") 70 | except: 71 | pass 72 | 73 | proposed_rules = xml.xpath("PRORULES/PRORULE") 74 | for rule in proposed_rules: 75 | # In Python 3, use encoding='unicode' 76 | # In Python 2, use encoding='utf-8' and decode 77 | all_text = etree.tostring(rule, encoding='unicode', method='text') 78 | 79 | doc = { 80 | "url": xml_url, 81 | "created_timestamp": creation_date, 82 | "downloaded_timestamp": datetime.date.today().strftime("%m-%d-%Y"), 83 | "text": all_text 84 | } 85 | docs.append(doc) 86 | 87 | return docs 88 | 89 | 90 | def main(): 91 | # Request raw data directly using bulk data API 92 | xmls = request_raw_data() 93 | docs = extract_rule_docs(xmls) 94 | 95 | # Shuffle and split into train / validation 96 | random.seed(0) 97 | random.shuffle(docs) 98 | train = docs[:int(len(docs)*0.75)] 99 | validation = docs[int(len(docs)*0.75):] 100 | 101 | save_to_file(train, OUT_DIR, "train.federal_register.jsonl") 102 | save_to_file(validation, OUT_DIR, "validation.federal_register.jsonl") 103 | 104 | 105 | if __name__ == '__main__': 106 | main() -------------------------------------------------------------------------------- /dataset_creation/federal_rules/frcp.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from bs4 import BeautifulSoup 3 | import datetime 4 | import random 5 | import numpy as np 6 | import lzma as xz 7 | import json 8 | 9 | def scrape_rule(url_val): 10 | page = requests.get(url_val) 11 | soup = BeautifulSoup(page.content, "html.parser") 12 | rule_content = soup.find(id="content1") 13 | if rule_content == None: 14 | return "" 15 | remove_text = rule_content.find(id="book-navigation-4940065") 16 | add_title = soup.find(id="page-title") 17 | # removing the bottom navigation titles that get scraped with 
the rule body 18 | stop_idx = rule_content.text.find(remove_text.text) 19 | # appending the name and editd rule content 20 | final_rule = add_title.text + rule_content.text[:stop_idx].strip() 21 | return final_rule 22 | 23 | # saving the train / val files 24 | def save_final_data(state_json_data, final_path): 25 | with xz.open(final_path, 'w') as state_data: 26 | for state_doc in state_json_data: 27 | state_data.write((json.dumps(state_doc) + "\n").encode("utf-8")) 28 | 29 | basic_url_val = "https://www.law.cornell.edu/rules/frcp/rule_" 30 | rule_idx = [str(x + 1) for x in list(range(86))] 31 | rule_idx.extend(["A", "B", "C", "D", "E", "F", "G"]) 32 | 33 | # get the number of rules in an article 34 | docs = [] 35 | for rule in rule_idx: 36 | full_rule_url = basic_url_val + rule 37 | rule_content = scrape_rule(full_rule_url) 38 | if rule_content == "": 39 | continue 40 | docs.append({ 41 | "url" : full_rule_url, 42 | "text" : rule_content, 43 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 44 | "created_timestamp" : "" 45 | }) 46 | 47 | # train / test split from process_tax_corpus_jhu.py 48 | random.seed(0) 49 | rand_idx = list(range(len(docs))) 50 | random.shuffle(rand_idx) 51 | 52 | train_idx = rand_idx[:int(len(rand_idx)*0.75)] 53 | val_idx = rand_idx[int(len(rand_idx)*0.75):] 54 | 55 | train_docs = np.array(docs)[train_idx] 56 | val_docs = np.array(docs)[val_idx] 57 | 58 | # saving data 59 | save_final_data(train_docs, "train.frcp.jsonl.xz") 60 | save_final_data(val_docs, "validation.frcp.jsonl.xz") 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /dataset_creation/federal_rules/fre.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from bs4 import BeautifulSoup 3 | import datetime 4 | import random 5 | import numpy as np 6 | import lzma as xz 7 | import json 8 | 9 | def scrape_rule(url_val): 10 | page = requests.get(url_val) 11 | soup = BeautifulSoup(page.content, "html.parser") 12 | rule_content = soup.find(id="content1") 13 | if rule_content == None: 14 | return "" 15 | remove_text = rule_content.find(id="book-navigation-4940270") 16 | add_title = soup.find(id="page-title") 17 | # removing the bottom navigation titles that get scraped with the rule body 18 | stop_idx = rule_content.text.find(remove_text.text) 19 | # appending the name and editd rule content 20 | final_rule = add_title.text + rule_content.text[:stop_idx].strip() 21 | return final_rule 22 | 23 | # saving the train / val files 24 | def save_final_data(state_json_data, final_path): 25 | with xz.open(final_path, 'w') as state_data: 26 | for state_doc in state_json_data: 27 | state_data.write((json.dumps(state_doc) + "\n").encode("utf-8")) 28 | 29 | basic_url_val = "https://www.law.cornell.edu/rules/fre/" 30 | 31 | # get the number of rules in an article 32 | total_sub_rules_per_rule = [] 33 | page = requests.get(basic_url_val) 34 | soup = BeautifulSoup(page.content, "html.parser") 35 | rule_table_of_contents = soup.find(id="content1") 36 | sub_rules = rule_table_of_contents.find_all("ol", class_="bullet") 37 | for sub in sub_rules: 38 | total_rules = len(sub.find_all("li")) 39 | if total_rules <= 15: 40 | total_sub_rules_per_rule.append(total_rules) 41 | 42 | docs = [] 43 | for i in range(1, 12): 44 | specific_r_count = total_sub_rules_per_rule[i - 1] 45 | for j in range(1, specific_r_count + 1): 46 | rule_num = i*100 + j 47 | rule_url = basic_url_val + "rule_" + 
str(rule_num) 48 | rule_content = scrape_rule(rule_url) 49 | if rule_content == "": 50 | continue 51 | docs.append({ 52 | "url" : rule_url, #TODO how to pull google drive links 53 | "text" : rule_content, 54 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 55 | "created_timestamp" : "" 56 | }) 57 | 58 | # train / test split from process_tax_corpus_jhu.py 59 | random.seed(0) 60 | rand_idx = list(range(len(docs))) 61 | random.shuffle(rand_idx) 62 | 63 | train_idx = rand_idx[:int(len(rand_idx)*0.75)] 64 | val_idx = rand_idx[int(len(rand_idx)*0.75):] 65 | 66 | train_docs = np.array(docs)[train_idx] 67 | val_docs = np.array(docs)[val_idx] 68 | 69 | # saving data 70 | save_final_data(train_docs, "train.fre.jsonl.xz") 71 | save_final_data(val_docs, "validation.fre.jsonl.xz") 72 | 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /dataset_creation/founder_docs/scrape_founding_docs.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | url= "https://founders.archives.gov/Metadata/founders-online-metadata.json" 4 | url_template = "https://www.constituteproject.org/constitution/{const_id}?lang=en" 5 | 6 | from bs4 import BeautifulSoup 7 | from bs4 import BeautifulSoup 8 | import requests_cache 9 | import os 10 | import time 11 | import os 12 | from tqdm import tqdm 13 | import numpy as np 14 | import json 15 | import random 16 | import datetime 17 | requests = requests_cache.CachedSession('scotus') 18 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'} 19 | try: 20 | import lzma as xz 21 | except ImportError: 22 | import pylzma as xz 23 | 24 | cache = "./cache" 25 | 26 | if not os.path.exists(cache): 27 | os.mkdir(cache) 28 | 29 | 30 | response = requests.get(url, headers=headers) 31 | ids = [r["permalink"] for r in response.json()] 32 | overwrite = True 33 | open_type = 'w' if overwrite else 'a' 34 | train_f = xz.open("./cache/train.founding_docs.xz", open_type) 35 | val_f = xz.open("./cache/validation.founding_docs.xz", open_type) 36 | 37 | for url in tqdm(ids): 38 | # url = url_template.format(const_id= _id) 39 | tag = url.split("documents/")[-1] 40 | url = f"https://founders.archives.gov/API/docdata/{tag}" 41 | response = requests.get(url, headers=headers) 42 | if not response.from_cache: 43 | time.sleep(.1) 44 | 45 | try: 46 | _dict = response.json() 47 | except: 48 | print("Problem loading response") 49 | print(response.content) 50 | continue 51 | text = f"Title: {_dict['title']}\nFrom: {','.join(_dict['authors'])}\nTo: {','.join(_dict['recipients'])}\n\n{_dict['content']}" 52 | if _dict['date-from'] is not None: 53 | created_date = "-".join(_dict['date-from'].split("-")[1:]) + "-" + _dict['date-from'].split("-")[0] 54 | else: 55 | created_date = "" 56 | 57 | datapoint = { 58 | "text" : text, 59 | "created_timestamp" : created_date, 60 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 61 | "url" : url 62 | } 63 | 64 | # This dataset is already heavily US biased, so it would be really weird not to train on the US 65 | # constitution 66 | if random.random() > .75: 67 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 68 | else: 69 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 70 | 71 | -------------------------------------------------------------------------------- /dataset_creation/ftcadvisories/scrape_ftc_advisories.py: 
-------------------------------------------------------------------------------- 1 | from bs4 import BeautifulSoup 2 | from bs4 import BeautifulSoup 3 | import requests_cache 4 | import os 5 | import time 6 | import os 7 | import textract 8 | import numpy as np 9 | import json 10 | import random 11 | import datetime 12 | requests = requests_cache.CachedSession('scotus') 13 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'} 14 | try: 15 | import lzma as xz 16 | except ImportError: 17 | import pylzma as xz 18 | 19 | cache = "./cache" 20 | 21 | if not os.path.exists(cache): 22 | os.mkdir(cache) 23 | url_template = "https://www.ftc.gov/policy/advisory-opinions?page={page}" 24 | 25 | 26 | overwrite = True 27 | open_type = 'w' if overwrite else 'a' 28 | train_f = xz.open("./cache/train.ftc_advisory_opinions.xz", open_type) 29 | val_f = xz.open("./cache/validation.ftc_advisory_opinions.xz", open_type) 30 | 31 | for page in range(0, 20): 32 | print(page) 33 | 34 | url = url_template.format(page = page) 35 | response = requests.get(url, headers=headers) 36 | soup= BeautifulSoup(response.text, "html.parser") 37 | potential_links = soup.select("a[href$='.pdf']") 38 | # import pdb; pdb.set_trace() 39 | potential_links = [link for link in potential_links if "advisory_opinions" in link["href"]] 40 | # potential_dates = [d.get_text() for d in soup.find_all("span", {"class" : "date-display-single"})] 41 | 42 | # assert len(potential_dates) == len(potential_links) 43 | if len(potential_links) == 0: 44 | print(f"SKipping for year {page} because no pdf links") 45 | continue 46 | 47 | for link in potential_links: 48 | try: 49 | response = requests.get(link["href"], headers=headers) 50 | except: 51 | print(f"PROBLEM GETTING {link}") 52 | time.sleep(random.random()*5) 53 | continue 54 | 55 | if not response.from_cache: 56 | time.sleep(random.random()*2.) 
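# Editor's note (descriptive comment added for clarity; not in the original
# script): the block that follows caches each advisory PDF locally, extracts
# its text with textract, retries with tesseract OCR when plain extraction
# returns little more than form-feed characters, and then deletes the cached
# PDF before building the datapoint.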
57 | 58 | 59 | with open(f'{cache}/{link["href"].split("/")[-1]}', 'wb') as f: 60 | f.write(response.content) 61 | 62 | try: 63 | text = str(textract.process(f'{cache}/{link["href"].split("/")[-1]}')) 64 | except: 65 | print(f"Problem with {link['href']}") 66 | continue 67 | if len(text) < len('\\x0c\\x0c') * 2: 68 | try: 69 | text = str(textract.process(f'{cache}/{link["href"].split("/")[-1]}', method='tesseract')) 70 | except: 71 | print(f"Problem with {link['href']}") 72 | continue 73 | if len(text) < len('\\x0c\\x0c') * 2: 74 | continue 75 | 76 | os.remove(f'{cache}/{link["href"].split("/")[-1]}') 77 | datapoint = { 78 | "text" : text, 79 | "created_timestamp" : "", 80 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 81 | "url" : link["href"] 82 | } 83 | # import pdb; pdb.set_trace() 84 | if random.random() > .75: 85 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 86 | else: 87 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) -------------------------------------------------------------------------------- /dataset_creation/oig_reports/scrape_oig_reports.py: -------------------------------------------------------------------------------- 1 | url = "https://archive.org/download/us-inspectors-general.bulk/us-inspectors-general.bulk.zip" 2 | 3 | import zipfile 4 | import os 5 | import datetime 6 | import random 7 | try: 8 | import lzma as xz 9 | except ImportError: 10 | import pylzma as xz 11 | 12 | cache_path = "./cache/us-inspectors-general.bulk.zip" 13 | if not os.path.exists(cache_path): 14 | import wget 15 | filename = wget.download(url, out="./cache/") 16 | 17 | import os 18 | import zipfile 19 | overwrite = True 20 | open_type = 'w' if overwrite else 'a' 21 | train_f = xz.open("./cache/train.oig.xz", open_type) 22 | val_f = xz.open("./cache/validation.oig.xz", open_type) 23 | import os, json 24 | 25 | i = 0 26 | with zipfile.ZipFile(cache_path) as z: 27 | for filename in z.namelist(): 28 | if not os.path.isdir(filename) and ".txt" in filename: 29 | # read the file 30 | with z.open(filename) as f: 31 | year = filename.split("/")[1] 32 | text = str(f.read()) 33 | 34 | datapoint = { 35 | "url" : url, 36 | "text" : text, 37 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 38 | "created_timestamp" : year 39 | } 40 | 41 | if random.random() > .75: 42 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 43 | else: 44 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 45 | i += 1 46 | 47 | if i % 5000 == 0: print(i) -------------------------------------------------------------------------------- /dataset_creation/olc/scrape_olc_memos.py: -------------------------------------------------------------------------------- 1 | from bs4 import BeautifulSoup 2 | import requests 3 | import os 4 | import time 5 | import os 6 | import textract 7 | import numpy as np 8 | import json 9 | import random 10 | import datetime 11 | 12 | try: 13 | import lzma as xz 14 | except ImportError: 15 | import pylzma as xz 16 | 17 | cache = "./cache" 18 | 19 | if not os.path.exists(cache): 20 | os.mkdir(cache) 21 | base_url = "https://www.justice.gov/olc/opinions" 22 | 23 | pages = 138 24 | docs = [] 25 | 26 | def _get_opinions(soup): 27 | g_data = soup.findAll("td", {"class": "views-field-field-opinion-attachment-file"}) 28 | g_dates = soup.findAll("span", {"class" : "date-display-single"}) 29 | for data, date in zip(g_data, g_dates): 30 | link = "https://www.justice.gov/" + data.find("a")["href"] 31 | tag = 
link.split("/")[-2] 32 | if not os.path.exists(f'{cache}/{tag}.pdf'): 33 | print(link) 34 | pdf = requests.get(link) 35 | 36 | with open(f'{cache}/{tag}.pdf', 'wb') as f: 37 | f.write(pdf.content) 38 | 39 | text = str(textract.process(f'{cache}/{tag}.pdf')) 40 | 41 | issuance_date = date.text 42 | print(issuance_date) 43 | issuance_date = issuance_date.replace("/", "-") 44 | docs.append({ 45 | "text" : text, 46 | "created_timestamp" : issuance_date, 47 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 48 | "url" : link 49 | }) 50 | 51 | 52 | pages = 138 53 | page2 = requests.get(base_url) 54 | 55 | soup = BeautifulSoup(page2.text, "lxml") 56 | button_next = soup.find("a", {"title": "Go to next page"}, href=True) 57 | 58 | _get_opinions(soup) 59 | 60 | while button_next: 61 | time.sleep(2)#delay time requests are sent so we don\'t get kicked by server 62 | url2 = "https://www.justice.gov{0}".format(button_next["href"]) 63 | page2=requests.get(url2) 64 | soup=BeautifulSoup(page2.text,"lxml") 65 | _get_opinions(soup) 66 | button_next = soup.find("a", {"title": "Go to next page"}, href=True) 67 | 68 | def save_to_processed(train, val, source_name, out_path): 69 | if not os.path.exists(out_path): 70 | os.makedirs(out_path) 71 | tf = os.path.join(out_path, f"train.{source_name}.jsonl") 72 | with open(tf, mode='w', encoding='utf-8') as out_file: 73 | for line in train: 74 | out_file.write(json.dumps(line) + "\n") 75 | print(f"Written {len(train)} documents to {tf}") 76 | 77 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl") 78 | with open(vf, mode='w', encoding='utf-8') as out_file: 79 | for line in val: 80 | out_file.write(json.dumps(line) + "\n") 81 | print(f"Written {len(val)} documents to {vf}") 82 | 83 | # now compress with lib 84 | print("compressing files...") 85 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out: 86 | out.write(xz.compress(bytes(f.read()))) 87 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out: 88 | out.write(xz.compress(bytes(f.read()))) 89 | print("compressed") 90 | 91 | random.seed(0) # important for shuffling 92 | rand_idx = list(range(len(docs))) 93 | random.shuffle(rand_idx) 94 | 95 | train_idx = rand_idx[:int(len(rand_idx)*0.75)] 96 | val_idx = rand_idx[int(len(rand_idx)*0.75):] 97 | 98 | train_docs = np.array(docs)[train_idx] 99 | val_docs = np.array(docs)[val_idx] 100 | 101 | save_to_processed(train_docs, val_docs, "olcmemos", "./cache/") 102 | 103 | -------------------------------------------------------------------------------- /dataset_creation/process_cfr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Processes CFR XML files. 
3 | # Raw XML can be downloaded from: https://www.govinfo.gov/bulkdata/CFR/2020/CFR-2020.zip 4 | 5 | import xml.etree.ElementTree as ET 6 | import glob 7 | import os 8 | import json 9 | import random 10 | from tqdm import tqdm 11 | random.seed(0) # important for shuffling 12 | 13 | # FILE PATHS 14 | RAW_PATH = "../data/cfr/raw" # path to title folders 15 | OUT_PATH = "../data/cfr/processed" # path to write data out to 16 | 17 | 18 | files = glob.glob(os.path.join(RAW_PATH, "title-*/*.xml")) 19 | print(f"{len(files)} total documents.") 20 | outputs = [] 21 | for file in tqdm(files): 22 | doc_name = file.split("/")[-1] 23 | tree = ET.parse(file) 24 | all_text = ET.tostring(tree.getroot(), encoding='utf-8', method='text') 25 | outputs.append({ 26 | "text": all_text.decode("utf-8"), 27 | "url": "https://www.govinfo.gov/bulkdata/CFR/2020/CFR-2020.zip", 28 | "timestamp": "08-17-2021" 29 | }) 30 | 31 | random.shuffle(outputs) 32 | train = outputs[:int(len(files)*0.75)] 33 | val = outputs[int(len(files)*0.75):] 34 | 35 | # save to processed 36 | tf = os.path.join(OUT_PATH, "train.cfr.jsonl") 37 | with open(tf, "w") as out_file: 38 | for o in train: 39 | out_file.write(json.dumps(o) + "\n") 40 | print(f"Written {len(train)} documents to {tf}") 41 | 42 | vf = os.path.join(OUT_PATH, "validation.cfr.jsonl") 43 | with open(vf, "w") as out_file: 44 | for o in val: 45 | out_file.write(json.dumps(o) + "\n") 46 | print(f"Written {len(val)} documents to {vf}") -------------------------------------------------------------------------------- /dataset_creation/process_scotus_oral_arguments.py: -------------------------------------------------------------------------------- 1 | # Script for extracting transcript data from SCOTUS Oral Arguments. 2 | # Source: https://github.com/walkerdb/supreme_court_transcripts/releases/tag/2021-08-14 3 | # We prepend the speaker to each line of text. 
Where the speaker is unknown, we use "Speaker" 4 | 5 | import glob 6 | import json 7 | import os 8 | import random 9 | from tqdm import tqdm 10 | from dateutil import parser 11 | import re 12 | import lzma 13 | import datetime# import datetime 14 | 15 | 16 | URL = "https://github.com/walkerdb/supreme_court_transcripts/releases/tag/2021-08-14" 17 | IN_DIR = "../../data/scotus_oral/supreme_court_transcripts-2021-08-14/oyez/cases/" 18 | OUT_DIR = "../../data/scotus_oral/processed" 19 | 20 | def process_transcript(d): 21 | doc = [] 22 | try: 23 | for el in d["transcript"]['sections']: 24 | for turn in el['turns']: 25 | speaker = "Speaker" 26 | if turn['speaker'] is not None: 27 | speaker = turn['speaker']['name'] 28 | text = [t['text'] for t in turn['text_blocks']] 29 | text = " ".join(text) 30 | text = f"{speaker}: {text}" 31 | doc.append(text) 32 | except: 33 | print(d) 34 | return "\n".join(doc) 35 | 36 | def save_to_file(data, fpath): 37 | with open(fpath, "w") as out_file: 38 | for x in data: 39 | out_file.write(json.dumps(x) + "\n") 40 | print(f"Written {len(data)} to {fpath}") 41 | 42 | def main(): 43 | 44 | in_files = glob.glob(os.path.join(IN_DIR, "*t*.json")) 45 | 46 | docs = [] 47 | for f in tqdm(in_files): 48 | data = json.load(open(f)) 49 | if data['unavailable'] or data['transcript'] is None: 50 | continue 51 | text = process_transcript(data) 52 | 53 | # Get oral argument date 54 | title = data['title'] 55 | match = re.search(r'(Jan(uary)?|Feb(ruary)?|Mar(ch)?|Apr(il)?|May|Jun(e)?|Jul(y)?|Aug(ust)?|Sep(tember)?|Oct(ober)?|Nov(ember)?|Dec(ember)?)\s+\d{2},\s+\d{4}', title) 56 | oral_date = parser.parse(match.group()).strftime("%m-%d-%Y") 57 | 58 | doc = { 59 | "url": URL, 60 | "created_timestamp": oral_date, 61 | "timestamp": datetime.date.today().strftime("%m-%d-%Y"), 62 | "text": text 63 | } 64 | docs.append(doc) 65 | 66 | # shuffle and split into train / validation 67 | random.seed(0) 68 | random.shuffle(docs) 69 | train = docs[:int(len(docs)*0.75)] 70 | validation = docs[int(len(docs)*0.75):] 71 | 72 | save_to_file(train, os.path.join(OUT_DIR, "train.scotus_oral.jsonl")) 73 | save_to_file(validation, os.path.join(OUT_DIR, "validation.scotus_oral.jsonl")) 74 | 75 | if __name__ == "__main__": 76 | main() -------------------------------------------------------------------------------- /dataset_creation/r_legaladvice/scrape_r_legaladvice.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import time 3 | import random 4 | import datetime 5 | import json 6 | try: 7 | import lzma as xz 8 | except ImportError: 9 | import pylzma as xz 10 | subreddits = ['legaladvice', 'legaladviceofftopic'] 11 | maxThings = -1 12 | printWait = 2 13 | requestSize = 100 14 | from profanity_check import predict, predict_prob 15 | 16 | 17 | def requestJSON(url): 18 | while True: 19 | try: 20 | r = requests.get(url) 21 | if r.status_code != 200: 22 | print('error code', r.status_code) 23 | time.sleep(5) 24 | continue 25 | else: 26 | break 27 | except Exception as e: 28 | print(e) 29 | time.sleep(5) 30 | continue 31 | return r.json() 32 | overwrite = False 33 | open_type = 'w' if overwrite else 'a' 34 | train_f = xz.open("./cache/train.r_legaladvice.xz", open_type) 35 | val_f = xz.open("./cache/validation.r_legaladvice.xz", open_type) 36 | 37 | for subreddit in subreddits: 38 | 39 | meta = requestJSON('https://api.pushshift.io/meta') 40 | limitPerMinute = meta['server_ratelimit_per_minute'] 41 | requestWait = 60 / limitPerMinute 42 | 43 | 
print('server_ratelimit_per_minute', limitPerMinute) 44 | 45 | 46 | 47 | things = ('submission',) 48 | def get_comments_for_post(post_id): 49 | url = 'https://api.pushshift.io/reddit/comment/search?link_id='\ 50 | + post_id \ 51 | + f'&metadata=true&limit=20000&subreddit={subreddit}' 52 | jsond = requestJSON(url) 53 | time.sleep(requestWait) 54 | return jsond['data'] 55 | 56 | for thing in things: 57 | i = 0 58 | 59 | print('\n[starting', thing + 's]') 60 | 61 | if maxThings < 0: 62 | 63 | url = 'https://api.pushshift.io/reddit/search/'\ 64 | + thing + '/?subreddit='\ 65 | + subreddit\ 66 | + '&metadata=true&size=0' 67 | 68 | jsond = requestJSON(url) 69 | 70 | totalResults = jsond['metadata']['total_results'] 71 | print('total ' + thing + 's', 'in', subreddit,':', totalResults) 72 | else: 73 | totalResults = maxThings 74 | print('downloading most recent', maxThings) 75 | 76 | 77 | created_utc = '1612822911' if subreddit == 'legaladvice' else '' 78 | 79 | startTime = time.time() 80 | timePrint = startTime 81 | while True: 82 | url = 'http://api.pushshift.io/reddit/search/'\ 83 | + thing + '/?subreddit=' + subreddit\ 84 | + '&size=' + str(requestSize)\ 85 | + '&before=' + str(created_utc) 86 | 87 | jsond = requestJSON(url) 88 | 89 | if len(jsond['data']) == 0: 90 | break 91 | 92 | doneHere = False 93 | for post in jsond['data']: 94 | created_utc = post["created_utc"] 95 | # f.write(str(post) + '\n') 96 | i += 1 97 | comments = get_comments_for_post(post['id']) 98 | 99 | filtered_comments = [comment for comment in comments if comment['score'] >= 8 \ 100 | and predict_prob([comment['body']])[0] < .8 \ 101 | and '[removed]' \ 102 | not in comment['body'] \ 103 | and '[deleted]' \ 104 | not in comment['body'] \ 105 | and post['id'] in comment['parent_id'] 106 | ] 107 | 108 | if len(filtered_comments) == 0: 109 | continue 110 | 111 | if 'selftext' in post: 112 | selector = 'selftext' 113 | elif 'text' in post: 114 | selector = 'text' 115 | elif 'body' in post: 116 | selector = 'body' 117 | else: 118 | print("Couldn't find a selector for post text, continuing") 119 | print(post) 120 | continue 121 | post_text = post[selector] 122 | text = f"Title: {post['title']}\nQuestion:{post_text}\n" 123 | if 'link_flair_text' in post and post['link_flair_text']: 124 | text += f"Topic:\n{post['link_flair_text']}\n" 125 | elif 'link_flair_richtext' in post and post['link_flair_richtext']: 126 | text += f"Topic:\n{post['link_flair_text']}\n" 127 | 128 | for q, comment in enumerate(sorted(filtered_comments, key=lambda x: x['score'], reverse=True)): 129 | text += f'Answer #{q+1}: {comment["body"]}' 130 | 131 | datapoint = { 132 | "url" : post['url'], 133 | "text" : text, 134 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 135 | "created_timestamp" : datetime.datetime.fromtimestamp(created_utc).strftime("%m-%d-%Y") 136 | } 137 | if random.random() > .75: 138 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 139 | else: 140 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 141 | 142 | if i >= totalResults: 143 | doneHere = True 144 | break 145 | 146 | if doneHere: 147 | break 148 | 149 | if time.time() - timePrint > printWait: 150 | timePrint = time.time() 151 | percent = i / totalResults * 100 152 | 153 | timePassed = time.time() - startTime 154 | 155 | print('{:.2f}'.format(percent) + '%', '|', 156 | time.strftime("%H:%M:%S", time.gmtime(timePassed))) 157 | 158 | 159 | time.sleep(requestWait) 160 | 161 | train_f.close() 162 | val_f.close() 163 | 
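# ---------------------------------------------------------------------------
# Editor's note -- illustrative addition, not part of the original script.
# Most scrapers in dataset_creation emit xz-compressed JSON Lines shards
# (others write plain .jsonl) whose records carry "url", "text",
# "downloaded_timestamp" and "created_timestamp". A minimal sketch for reading
# one such shard back, assuming the ./cache/train.r_legaladvice.xz file
# written above exists:
import json
import lzma
import os

_example_shard = "./cache/train.r_legaladvice.xz"
if os.path.exists(_example_shard):
    with lzma.open(_example_shard, "rt", encoding="utf-8") as _shard:
        for _line in _shard:
            _record = json.loads(_line)  # one JSON object per line
            print(_record["url"], len(_record["text"]))
# ---------------------------------------------------------------------------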
-------------------------------------------------------------------------------- /dataset_creation/scotus_dockets/scrape_scotus_dockets.py: -------------------------------------------------------------------------------- 1 | from bs4 import BeautifulSoup 2 | from bs4 import BeautifulSoup 3 | import requests_cache 4 | import os 5 | import time 6 | import os 7 | import textract 8 | import numpy as np 9 | import json 10 | import random 11 | import datetime 12 | requests = requests_cache.CachedSession('scotus') 13 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'} 14 | try: 15 | import lzma as xz 16 | except ImportError: 17 | import pylzma as xz 18 | 19 | cache = "./cache" 20 | 21 | if not os.path.exists(cache): 22 | os.mkdir(cache) 23 | url_template_default = "https://www.supremecourt.gov/search.aspx?filename=/docket/docketfiles/html/public/{year}-{n}.html" 24 | aurl_template = "https://www.supremecourt.gov/search.aspx?filename=/docket/docketfiles/html/public/{year}A{n}.html" 25 | 26 | 27 | overwrite = True 28 | open_type = 'w' if overwrite else 'a' 29 | train_f = xz.open("./cache/train.scotus_docket_entries.xz", open_type) 30 | val_f = xz.open("./cache/validation.scotus_docket_entries.xz", open_type) 31 | 32 | for year in list(range(18, 22))[::-1]: 33 | 34 | start_ns = [0, 5001, "A"] 35 | # print(year) 36 | # if year == 21: 37 | # vals = list(range(1,558)) + list(range(5001)) 38 | # elif year == 20: 39 | # vals = range(1, 8478) 40 | # elif year == 19: 41 | # vals = range(1, 8930) 42 | # elif year == 18: 43 | # vals = range(1, 9849) 44 | # else: 45 | # raise ValueError("blerg") 46 | 47 | for start_n in start_ns: 48 | if start_n == "A": 49 | start_n = 0 50 | url_template = aurl_template 51 | else: 52 | url_template = url_template_default 53 | 54 | no_link_count = 0 55 | for n in range(start_n, start_n+5000): 56 | url = url_template.format(year=year, n=n) 57 | response = requests.get(url, headers=headers) 58 | soup= BeautifulSoup(response.text, "html.parser") 59 | potential_links = soup.select("a[href$='.pdf']") 60 | # import pdb; pdb.set_trace() 61 | potential_links = [link for link in potential_links if "DocketPDF" in link["href"]] 62 | 63 | if len(potential_links) == 0: 64 | print(f"SKipping for year {year} at n {n} because no pdf links") 65 | no_link_count += 1 66 | if no_link_count <= 10: 67 | continue 68 | else: 69 | break 70 | 71 | no_link_count = 0 72 | 73 | for link in potential_links: 74 | try: 75 | response = requests.get(link["href"], headers=headers) 76 | except: 77 | print(f"PROBLEM GETTING {link}") 78 | time.sleep(random.random()*5) 79 | continue 80 | 81 | if not response.from_cache: 82 | time.sleep(random.random()*2.) 
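# Editor's note (descriptive comment added for clarity; not in the original
# script): the lines below derive a local filename for each docket PDF from
# its URL and fall back to a random uuid whenever that name would exceed 200
# characters, so unusually long attachment names do not break the write.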
83 | 84 | filepath = f'{cache}/{link["href"].split("/")[-1]}' 85 | if len(filepath) > 200: 86 | import uuid 87 | filepath = f'{cache}/{str(uuid.uuid4())}' 88 | 89 | with open(filepath, 'wb') as f: 90 | f.write(response.content) 91 | 92 | try: 93 | text = str(textract.process(filepath)) 94 | except: 95 | print(f"Problem with {link['href']}") 96 | continue 97 | if len(text) < len('\\x0c\\x0c') * 2: 98 | try: 99 | text = str(textract.process(filepath, method='tesseract')) 100 | except: 101 | print(f"Problem with {link['href']}") 102 | continue 103 | if len(text) < len('\\x0c\\x0c') * 2: 104 | continue 105 | 106 | os.remove(filepath) 107 | datapoint = { 108 | "text" : text, 109 | "created_timestamp" : f'{link["href"].split("/")[-1][4:6]}-{link["href"].split("/")[-1][6:8]}-{link["href"].split("/")[-1][0:4]}', 110 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 111 | "url" : link["href"] 112 | } 113 | if random.random() > .75: 114 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 115 | else: 116 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) -------------------------------------------------------------------------------- /dataset_creation/scrape_nlrb_decisions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import requests 3 | import re 4 | from tqdm import tqdm 5 | import wget 6 | import os 7 | import json 8 | import PyPDF2 9 | import glob 10 | import datetime 11 | import random 12 | 13 | from multiprocessing import Pool 14 | 15 | 16 | DIR = "../data/nlrb/scraping" 17 | OUT_DIR = "../data/nlrb/processed" 18 | BASE_URL = "https://www.nlrb.gov/cases-decisions/decisions/board-decisions?search_term=&op=Search&volume=-1&slip_opinion_number=&page_number=$PAGE_NUMBER&items_per_page=100&form_build_id=form-EXFTGMwfIM0yO2L6ENa30_RuqwbWvmk5YR1UYcvgqsA&form_id=board_decisions_form" 19 | 20 | 21 | def get_urls(args): 22 | all_matches = [] 23 | for page_num in tqdm(range(661)): 24 | url = BASE_URL.replace("$PAGE_NUMBER", str(page_num)) 25 | page = requests.get(url) 26 | text = page.text 27 | #p = re.compile("href=\"(https\:\/\/apps\.nlrb\.gov\/link\/document\.aspx\/[a-z0-9]*)\"") 28 | p = re.compile(" ([0-9]*\/[0-9]*\/[0-9]*) <\/td>\s*[0-9]* NLRB No\. 
[0-9]*<\/td>\s*\s* Sections -> section format 107 | # Tennesse has Chapters -> Chapter -> Sections -> section 108 | # some others have just Sections -> section (no Chapters) 109 | # alaska uses secs instead of sections 110 | # TO DO: montana has parts and part 111 | 112 | # some have sects instead of sections 113 | all_text = _unroll(file_content_dict) 114 | 115 | 116 | # if (isinstance(file_content_dict, list)) and (isinstance(file_content_dict[0], dict)): 117 | # top_keys = file_content_dict[0].keys() 118 | # if "link" in top_keys: 119 | # needed_link = file_content_dict[0]["link"] # for output 120 | # if "chapters" in top_keys: 121 | # # go to one level down look for sections 122 | # # multiple sections - need to parse here 123 | # for sect_elem in file_content_dict[0]["chapters"]: 124 | # sub_keys = sect_elem.keys() 125 | # if not sect_elem: 126 | # print (f" *** Empty dict hit -- {ind_file}") 127 | # continue 128 | # if "raws" in sub_keys: 129 | # to_parse_sections.append(sect_elem) 130 | # continue 131 | 132 | # if "sections" in sub_keys: 133 | # to_parse_sections = sect_elem["sections"] 134 | # elif "secs" in sub_keys: 135 | # to_parse_sections = sect_elem["secs"] 136 | # else: 137 | # print (f"*** subkeys {sub_keys} ") 138 | # print (f"*** ERROR - No Sections under Chapters in {ind_file} ") 139 | # elif "sections" in top_keys: 140 | # to_parse_sections = file_content_dict[0]["sections"] 141 | 142 | # to_parse_sections has the list of all sections, combine texts from all section elements 143 | 144 | # some section (especially the last one) is empty - check for these 145 | # for ind_sect in to_parse_sections: 146 | # if ind_sect: 147 | # if "raws" in ind_sect.keys(): 148 | # all_text = all_text + " "+ " ".join(ind_sect['raws']) 149 | 150 | # create dictionary for output 151 | out_dict = {} 152 | out_dict ["url"] = url # should the key be url or link 153 | out_dict ["text"] = all_text 154 | #out_dict ["created_timestamp"] = TIME_STAMP 155 | out_dict ["created_timestamp"] = ind_file.split("_")[-1].replace(".json", "") 156 | out_dict ["downloaded_timestamp"] = datetime.date.today().strftime("%m-%d-%Y") 157 | out_dict ["state_year"] = state_name_year 158 | 159 | if random.random() > .75: 160 | val_f.write((json.dumps(out_dict) + "\n").encode("utf-8")) 161 | else: 162 | train_f.write((json.dumps(out_dict) + "\n").encode("utf-8")) 163 | -------------------------------------------------------------------------------- /dataset_creation/state_codes/state_codes_from_scratch.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from bs4 import BeautifulSoup 3 | import datetime 4 | import random 5 | import numpy as np 6 | import lzma as xz 7 | import json 8 | 9 | # The other file was run to grab the state codes. This captures 10 | # all the states with no errors, but takes way too long to run (timing out). If 11 | # someone has capacity to get fix the bandwidth issues on this, feel free to take a stab! 
12 | 13 | def state_code_scrape(url_val): 14 | page = requests.get(url_val) 15 | soup = BeautifulSoup(page.content, "html.parser") 16 | results = soup.find(id="codes-content") 17 | return results.text 18 | 19 | def year_list_scrape(url_val): 20 | page = requests.get(url_val) 21 | soup = BeautifulSoup(page.content, "html.parser") 22 | results = soup.find(id="main-content") 23 | code_years = results.find("ul") 24 | year_list = [year.text[:4] for year in code_years] 25 | return year_list 26 | 27 | # saving the train / val files 28 | def save_final_data(state_json_data, final_path): 29 | with xz.open(final_path, 'w') as state_data: 30 | for state_doc in state_json_data: 31 | state_data.write((json.dumps(state_doc) + "\n").encode("utf-8")) 32 | 33 | def create_nested_doc(url_val, base_url): 34 | final_text = "" 35 | page = requests.get(url_val) 36 | soup = BeautifulSoup(page.content, "html.parser") 37 | title_sections = soup.find("div", class_="codes-listing") 38 | if title_sections == None: 39 | return state_code_scrape(url_val) 40 | else: 41 | title_section = title_sections.find_all("li") 42 | for sec in title_section: 43 | final_link = "" 44 | further_sec = sec.find_all('a', href=True) 45 | for l in further_sec: 46 | final_link = l['href'] 47 | if final_link != "": 48 | sec_url = base_url + final_link 49 | final_text += create_nested_doc(sec_url, base_url) 50 | return final_text 51 | 52 | state_names = ["Alaska", "Alabama", "Arkansas", "Arizona", "California", "Colorado", "Connecticut", "District-of-columbia", "Delaware", "Florida", "Georgia", "Guam", "Hawaii", "Iowa", "Idaho", "Illinois", "Indiana", "Kansas", "Kentucky", "Louisiana", "Massachusetts", "Maryland", "Maine", "Michigan", "Minnesota", "Missouri", "Mississippi", "Montana", "North-carolina", "North-dakota", "Nebraska", "New-hampshire", "New-jersey", "New-mexico", "Nevada", "New-york", "Ohio", "Oklahoma", "Oregon", "Pennsylvania", "Puerto-rico", "Rhode-island", "South-carolina", "South-dakota", "Tennessee", "Texas", "Utah", "Virginia", "Virgin-Islands", "Vermont", "Washington", "Wisconsin", "West-virginia", "Wyoming"] 53 | state_names = [x.lower() for x in state_names] 54 | 55 | base_url = "https://law.justia.com" 56 | docs = [] 57 | for state in state_names: 58 | state_base_url = base_url + '/codes/' + state 59 | year_list = year_list_scrape(state_base_url) 60 | for year in year_list: 61 | constructed_doc = "" 62 | year_url = state_base_url + '/' + year 63 | constructed_doc = create_nested_doc(year_url, base_url) 64 | docs.append({ 65 | "url" : year_url, #TODO how to pull google drive links 66 | "text" : constructed_doc, 67 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 68 | "created_timestamp" : "", 69 | "state_year": state + '_' + year 70 | }) 71 | print(constructed_doc) 72 | 73 | # train / test split from process_tax_corpus_jhu.py 74 | random.seed(0) 75 | rand_idx = list(range(len(docs))) 76 | random.shuffle(rand_idx) 77 | 78 | train_idx = rand_idx[:int(len(rand_idx)*0.75)] 79 | val_idx = rand_idx[int(len(rand_idx)*0.75):] 80 | 81 | train_docs = np.array(docs)[train_idx] 82 | val_docs = np.array(docs)[val_idx] 83 | 84 | # saving data 85 | save_final_data(train_docs, "train.state_codes.jsonl.xz") 86 | save_final_data(val_docs, "validation.state_codes.jsonl.xz") 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /dataset_creation/tax_corpus_jhu/process_tax_corpus_jhu.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import datetime 5 | import numpy as np 6 | # Assumes you've downloaded 7 | link = "https://archive.data.jhu.edu/file.xhtml?persistentId=doi:10.7281/T1/N1X6I4/D5CQ0Y&version=2.0" 8 | # into cache/ and extracted it 9 | try: 10 | import lzma as xz 11 | except ImportError: 12 | import pylzma as xz 13 | 14 | with open('cache/plrs_tc_corpus_feb25.txt', 'r') as f: 15 | docs = [{ 16 | "text" : " ".join(doc.split(" ")[1:]), 17 | "created_timestamp" : "", 18 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 19 | "url" : link 20 | } for doc in f.readlines()] # first word is always a file name 21 | 22 | # Note the preprocessing in this dataset is very messy. In the future we may wish to find the raw original proceedings. 23 | 24 | def save_to_processed(train, val, source_name, out_path): 25 | if not os.path.exists(out_path): 26 | os.makedirs(out_path) 27 | tf = os.path.join(out_path, f"train.{source_name}.jsonl") 28 | with open(tf, mode='w', encoding='utf-8') as out_file: 29 | for line in train: 30 | out_file.write(json.dumps(line) + "\n") 31 | print(f"Written {len(train)} documents to {tf}") 32 | 33 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl") 34 | with open(vf, mode='w', encoding='utf-8') as out_file: 35 | for line in val: 36 | out_file.write(json.dumps(line) + "\n") 37 | print(f"Written {len(val)} documents to {vf}") 38 | 39 | # now compress with lib 40 | print("compressing files...") 41 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out: 42 | out.write(xz.compress(bytes(f.read()))) 43 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out: 44 | out.write(xz.compress(bytes(f.read()))) 45 | print("compressed") 46 | 47 | random.seed(0) # important for shuffling 48 | rand_idx = list(range(len(docs))) 49 | random.shuffle(rand_idx) 50 | 51 | train_idx = rand_idx[:int(len(rand_idx)*0.75)] 52 | val_idx = rand_idx[int(len(rand_idx)*0.75):] 53 | 54 | train_docs = np.array(docs)[train_idx] 55 | val_docs = np.array(docs)[val_idx] 56 | 57 | save_to_processed(train_docs, val_docs, "taxrulings", "./cache/") 58 | -------------------------------------------------------------------------------- /dataset_creation/tos/process_tos.py: -------------------------------------------------------------------------------- 1 | # Processes TOS XML files. 
2 | # XML with annotation tags (clause type / whether clause is unfair) can be downloaded from: http://claudette.eui.eu/ToS.zip 3 | # XML files located in OriginalTaggedDocuments subdirectory 4 | 5 | import os 6 | import json 7 | import re 8 | import random 9 | import numpy as np 10 | import datetime 11 | from tqdm import tqdm 12 | from dateutil import parser 13 | from bs4 import BeautifulSoup 14 | 15 | # FILE PATHS 16 | RAW_PATH = "../../data/tos/raw" # path to contracts 17 | OUT_PATH = "../../data/tos/processed" # path to write data (docs of original text) out to 18 | OUT_PATH_TAGGED = "../../data/tos/processed_tagged" # path to write data (docs of original text with annotation tags) out to 19 | 20 | # Regex match for date 21 | match_text_patterns = ["Last Updated Date:", "Last Updated:", "Last updated:", "last updated on", "Last revised on", "Last Revised:", "Last revised on", "Revised", "Date of Last Revision:", "Last modified:", "Last modified by", "in effect as of", "Effective Date:", "Effective:", "effective on", "Effective on", "Applicable from:"] 22 | combined_text_regex = '(?:%s)' % '|'.join(match_text_patterns) 23 | 24 | match_date_patterns = [ 25 | r'(Jan(uary)?|Feb(ruary)?|Mar(ch)?|Apr(il)?|May|Jun(e)?|Jul(y)?|Aug(ust)?|Sep(tember)?|Oct(ober)?|Nov(ember)?|Dec(ember)?)\s+\d{1,2}(th)?,\s+\d{4}', 26 | r'\d{1,2}\s+(Jan(uary)?|Feb(ruary)?|Mar(ch)?|Apr(il)?|May|Jun(e)?|Jul(y)?|Aug(ust)?|Sep(tember)?|Oct(ober)?|Nov(ember)?|Dec(ember)?)(,)?\s+\d{4}', 27 | r'\d{1,2}/\d{1,2}/\d{4}', 28 | r'\d{4}-\d{1,2}-\d{1,2}' 29 | ] 30 | combined_date_regex = '(?:%s)' % '|'.join(match_date_patterns) 31 | 32 | def save_to_processed(train, val, source_name, out_path): 33 | if not os.path.exists(out_path): 34 | os.makedirs(out_path) 35 | tf = os.path.join(out_path, f"train.{source_name}.jsonl") 36 | with open(tf, mode='w', encoding='utf-8') as out_file: 37 | for line in train: 38 | out_file.write(json.dumps(line) + "\n") 39 | print(f"Written {len(train)} documents to {tf}") 40 | 41 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl") 42 | with open(vf, mode='w', encoding='utf-8') as out_file: 43 | for line in val: 44 | out_file.write(json.dumps(line) + "\n") 45 | print(f"Written {len(val)} documents to {vf}") 46 | 47 | def main(): 48 | files = [os.path.join(RAW_PATH, f) for f in os.listdir(RAW_PATH)] 49 | print(f"{len(files)} total documents.") 50 | outputs = [] 51 | outputs_tagged = [] 52 | for file in tqdm(files): 53 | print("Processing:", file) 54 | with open(file, mode='r', encoding='utf-8') as in_file: 55 | text_tagged = in_file.read() 56 | 57 | # Remove annotation tags 58 | soup = BeautifulSoup(text_tagged, features="lxml") 59 | text = soup.get_text() 60 | 61 | # Extract creation date 62 | lines = text.splitlines() 63 | date_text = None 64 | for line in lines: 65 | date_text = re.search(combined_text_regex, line) 66 | # If matched text describing date, break at current line 67 | if date_text: 68 | break 69 | creation_date = "" 70 | if date_text: 71 | match = re.search(combined_date_regex, line) 72 | # If matched date, parse match 73 | if match: 74 | creation_date = parser.parse(match.group()).strftime("%m-%d-%Y") 75 | 76 | outputs.append({ 77 | "url": "http://claudette.eui.eu/ToS.zip", 78 | "created_timestamp": creation_date, 79 | "downloaded_timestamp": datetime.date.today().strftime("%m-%d-%Y"), 80 | "text": text 81 | }) 82 | 83 | outputs_tagged.append({ 84 | "url": "http://claudette.eui.eu/ToS.zip", 85 | "created_timestamp": creation_date, 86 | "downloaded_timestamp": 
datetime.date.today().strftime("%m-%d-%Y"), 87 | "text": text_tagged 88 | }) 89 | 90 | outputs = np.array(outputs) 91 | outputs_tagged = np.array(outputs_tagged) 92 | 93 | random.seed(0) # important for shuffling 94 | rand_idx = list(range(len(outputs))) 95 | random.shuffle(rand_idx) 96 | 97 | train_idx = rand_idx[:int(len(rand_idx)*0.75)] 98 | val_idx = rand_idx[int(len(rand_idx)*0.75):] 99 | 100 | # Same split for docs of original text and original text with annotation tags 101 | train = outputs[train_idx] 102 | val = outputs[val_idx] 103 | 104 | train_tagged = outputs_tagged[train_idx] 105 | val_tagged = outputs_tagged[val_idx] 106 | 107 | # Save train / val to processed 108 | save_to_processed(train, val, "tos", OUT_PATH) 109 | 110 | # Save train_tagged / val_tagged to processed_tagged 111 | save_to_processed(train_tagged, val_tagged, "tos_tagged", OUT_PATH_TAGGED) 112 | 113 | if __name__ == "__main__": 114 | main() -------------------------------------------------------------------------------- /dataset_creation/un_debates/scrape_un_debates.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import os 4 | import textract 5 | import random 6 | import bs4 7 | import json 8 | import datetime 9 | try: 10 | import lzma as xz 11 | except ImportError: 12 | import pylzma as xz 13 | # TODO: can't actually get have to downlaod manually 14 | url = "https://dataverse.harvard.edu/file.xhtml?fileId=4590189&version=6.0#" 15 | 16 | cache_path = "./cache/UNGDC_1970-2020.tar.gz" 17 | if not os.path.exists(cache_path): 18 | raise ValueError(f"Need to download from {url}") 19 | # import wget 20 | # filename = wget.download(url, out="./cache/") 21 | import tarfile 22 | overwrite = True 23 | open_type = 'w' if overwrite else 'a' 24 | train_f = xz.open("./cache/train.undebates.xz", open_type) 25 | val_f = xz.open("./cache/validation.undebates.xz", open_type) 26 | # if fname.endswith("tar.gz"): 27 | tar = tarfile.open(cache_path, "r:gz") 28 | tar.extractall() 29 | tar.close() 30 | for path, subdirs, files in os.walk("./TXT/"): 31 | for name in files: 32 | if not name.endswith(".txt") or name[0] == ".": 33 | continue 34 | 35 | with open(os.path.join(path, name), "r") as f: 36 | text = str(f.read()) 37 | 38 | datapoint = { 39 | "text" : text, 40 | "created_timestamp" : name.split("_")[-1].replace(".txt", ""), 41 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 42 | "url" : url 43 | } 44 | 45 | if random.random() > .75: 46 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 47 | else: 48 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 49 | -------------------------------------------------------------------------------- /dataset_creation/us_bills/process_us_bills.py: -------------------------------------------------------------------------------- 1 | # Processes US congressional bills from the 108th - 117th Congress 2 | # Bulk data pulled from: https://www.govinfo.gov/bulkdata/xml/BILLSTATUS 3 | # Requests to bulk data API endpoint are prone to request errors, 4 | # use run.sh to call this script for each session for greater fault tolerance 5 | 6 | import os 7 | import sys 8 | import json 9 | import datetime 10 | import random 11 | from urllib.request import Request, urlopen 12 | # xpath only available in lxml etree, not ElementTree 13 | from lxml import etree 14 | from tqdm import tqdm 15 | from dateutil import parser 16 | import time 17 | import argparse 18 | 19 | # BASE URL (bulk data API 
endpoint) 20 | BASE_URL = "https://www.govinfo.gov/bulkdata/xml/BILLSTATUS" 21 | 22 | # DATA DIRS 23 | SAVE_DIR = "../../data/us_bills/save" # path to save per session data 24 | 25 | # Request variables 26 | headers = {'User-Agent': 'Mozilla/5.0', 'Accept': 'application/xml'} 27 | 28 | 29 | def save_to_file(data, out_dir, fname): 30 | if not os.path.exists(out_dir): 31 | os.makedirs(out_dir) 32 | fpath = os.path.join(out_dir, fname) 33 | with open(fpath, 'w') as out_file: 34 | for x in data: 35 | out_file.write(json.dumps(x) + "\n") 36 | print(f"Written {len(data)} to {fpath}") 37 | 38 | 39 | def request_status_xmls(session): 40 | urls = [os.path.join(BASE_URL, str(session))] 41 | status_xmls = [] 42 | while len(urls) > 0: 43 | next_urls = [] 44 | for url in urls: 45 | print(url) 46 | request = Request(url, headers=headers) 47 | time.sleep(random.uniform(0.02, 0.05)) 48 | with urlopen(request) as response: 49 | root = etree.fromstring(response.read()) 50 | elems = root.xpath("*/file[folder='true' and name!='resources']") 51 | if len(elems) > 0: 52 | for e in elems: 53 | next_url = e.find("link").text 54 | next_urls.append(next_url) 55 | else: 56 | elems = root.xpath("*/file[mimeType='application/xml']") 57 | for e in elems: 58 | xml_url = e.find("link").text 59 | # print(xml_url) 60 | request = Request(xml_url, headers=headers) 61 | for i in range(3): # retry request max of 3 times 62 | try: 63 | time.sleep(random.uniform(0.02, 0.05)) 64 | with urlopen(request) as response: 65 | xml = etree.fromstring(response.read()) 66 | # Add xml for bill status 67 | status_xmls.append(xml) 68 | break 69 | except: 70 | print("Retrying") 71 | 72 | urls = next_urls 73 | 74 | return status_xmls 75 | 76 | 77 | def request_raw_data(status_xmls): 78 | xmls = [] 79 | for status_xml in tqdm(status_xmls): 80 | # print(status_xml) 81 | # Text versions are sorted in date order, find returns first item, which is most recent version 82 | text_info = status_xml.find("bill/textVersions/item") 83 | if text_info is not None: 84 | try: 85 | date = text_info.find("date").text 86 | except: 87 | date = "" 88 | try: 89 | xml_url = text_info.find("formats/item/url").text 90 | except: 91 | xml_url = None 92 | if xml_url: 93 | request = Request(xml_url, headers=headers) 94 | for i in range(3): 95 | try: 96 | time.sleep(random.uniform(0.02, 0.05)) 97 | with urlopen(request) as response: 98 | xml = etree.fromstring(response.read()) 99 | # print(date, xml_url) 100 | # Add tuple of (date, xml_url, xml) for raw bill text 101 | xmls.append((date, xml_url, xml)) 102 | break 103 | except: 104 | print("Retrying") 105 | 106 | return xmls 107 | 108 | 109 | def prepare_docs(xmls): 110 | docs = [] 111 | for (date, xml_url, xml) in tqdm(xmls): 112 | try: 113 | creation_date = parser.parse(date).strftime("%m-%d-%Y") 114 | except: 115 | creation_date = "" 116 | 117 | # In Python 3, use encoding='unicode' 118 | # In Python 2, use encoding='utf-8' and decode 119 | all_text = etree.tostring(xml, encoding='unicode', method='text') 120 | 121 | doc = { 122 | "url": xml_url, 123 | "created_timestamp": creation_date, 124 | "downloaded_timestamp": datetime.date.today().strftime("%m-%d-%Y"), 125 | "text": all_text 126 | } 127 | docs.append(doc) 128 | 129 | return docs 130 | 131 | 132 | def main(): 133 | args = arg_parser.parse_args() 134 | 135 | # Request raw data directly using bulk data API 136 | print("Request status xmls") 137 | status_xmls = request_status_xmls(session=args.session) 138 | print("Request raw data") 139 | xmls = 
request_raw_data(status_xmls) 140 | print("Prepare docs") 141 | docs = prepare_docs(xmls) 142 | 143 | save_to_file(docs, SAVE_DIR, str(args.session) + ".us_bills.jsonl") 144 | 145 | 146 | if __name__ == '__main__': 147 | arg_parser = argparse.ArgumentParser() 148 | arg_parser.add_argument('--session', type=int, default=117) 149 | main() -------------------------------------------------------------------------------- /dataset_creation/us_bills/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python3 process_us_bills.py --session=117 3 | python3 process_us_bills.py --session=116 4 | python3 process_us_bills.py --session=115 5 | python3 process_us_bills.py --session=114 6 | python3 process_us_bills.py --session=113 7 | python3 process_us_bills.py --session=112 8 | python3 process_us_bills.py --session=111 9 | python3 process_us_bills.py --session=110 10 | python3 process_us_bills.py --session=109 11 | python3 process_us_bills.py --session=108 -------------------------------------------------------------------------------- /dataset_creation/us_bills/split_us_bills.py: -------------------------------------------------------------------------------- 1 | # Loads data files saved by session (108th-117th Congress) from the bulk data API endpoints 2 | # Splits all data into train / test data 3 | 4 | import os 5 | import glob 6 | import json 7 | import random 8 | from process_us_bills import save_to_file 9 | 10 | # DATA DIRS 11 | SAVE_DIR = "../../data/us_bills/save" # path to save per session data 12 | OUT_DIR = "../../data/us_bills/processed" # path to write data out to 13 | 14 | def main(): 15 | docs = [] 16 | # Load all sessions from saved per-session data files 17 | session_files = glob.glob(os.path.join(SAVE_DIR, "*.jsonl")) 18 | for session_file in session_files: 19 | print(session_file) 20 | with open(session_file, 'r') as file: 21 | for line in file: 22 | docs.append(json.loads(line)) 23 | 24 | # Shuffle and split into train / validation 25 | random.seed(0) 26 | random.shuffle(docs) 27 | train = docs[:int(len(docs)*0.75)] 28 | validation = docs[int(len(docs)*0.75):] 29 | 30 | print("Write train data") 31 | save_to_file(train, OUT_DIR, "train.us_bills.jsonl") 32 | print("Write validation data") 33 | save_to_file(validation, OUT_DIR, "validation.us_bills.jsonl") 34 | 35 | 36 | if __name__ == '__main__': 37 | main() -------------------------------------------------------------------------------- /dataset_creation/uscode/scrape_us_code.py: -------------------------------------------------------------------------------- 1 | url = "https://uscode.house.gov/download/releasepoints/us/pl/117/49/xml_uscAll@117-49.zip" 2 | 3 | import zipfile 4 | import os 5 | import textract 6 | import random 7 | import bs4 8 | import json 9 | import datetime 10 | try: 11 | import lzma as xz 12 | except ImportError: 13 | import pylzma as xz 14 | 15 | cache_path = "./cache/xml_uscAll@117-49.zip" 16 | if not os.path.exists(cache_path): 17 | import wget 18 | filename = wget.download(url, out="./cache/") 19 | 20 | with zipfile.ZipFile(cache_path, 'r') as zip_ref: 21 | zip_ref.extractall("./cache/") 22 | 23 | overwrite = True 24 | open_type = 'w' if overwrite else 'a' 25 | train_f = xz.open("./cache/train.uscode.xz", open_type) 26 | val_f = xz.open("./cache/validation.uscode.xz", open_type) 27 | def html2text(x): 28 | soup = bs4.BeautifulSoup(x, "lxml") 29 | return soup.get_text() 30 | 31 | docs = [] 32 | for path, subdirs, files in os.walk("./cache/"): 33 | for name 
in files: 34 | if not name.endswith(".xml"): 35 | continue 36 | with open(os.path.join(path, name), "r") as text_file: 37 | text = html2text(text_file.read()) 38 | # text = str(textract.process(os.path.join(path, name))) 39 | datapoint = { 40 | "text" : text, 41 | "created_timestamp" : "2021", 42 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"), 43 | "url" : url 44 | } 45 | 46 | if random.random() > .75: 47 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) 48 | else: 49 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8")) -------------------------------------------------------------------------------- /pretraining/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Vocab 3 | 4 | We use new_vocab.py for generating part of the vocabulary and the rest is supplemented via a random sampling of Black's law dictionary. 5 | 6 | # Chunkification and Randomization 7 | 8 | We use the chunkify_and_hd5.py file to segment and shuffle the data into shards. Note that each shard consists of a homogeneous sampling across data segments. We find that not doing this created extremely large instabilities. 9 | 10 | # Fine-tuning 11 | 12 | Fine-tuning code is in pol-finetuning-main -------------------------------------------------------------------------------- /pretraining/new_vocab.py: -------------------------------------------------------------------------------- 1 | # !pip install tokenizers 2 | import os 3 | from tokenizers import BertWordPieceTokenizer 4 | 5 | fileList = [os.path.join("./unzipped_nojson", x) for x in os.listdir("./unzipped_nojson") if x is not None] 6 | 7 | # initialize 8 | tokenizer = BertWordPieceTokenizer( 9 | clean_text=False, 10 | handle_chinese_chars=True, 11 | strip_accents=True, 12 | lowercase=True 13 | ) 14 | 15 | # and train 16 | tokenizer.train(files=fileList, vocab_size=29000, min_frequency=2, 17 | limit_alphabet=500, wordpieces_prefix='##', 18 | special_tokens=[ 19 | '[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']) 20 | 21 | tokenizer.save_model('./new_bert_vocab_2/', 'bert-wordpiece') 22 | 23 | -------------------------------------------------------------------------------- /pretraining/pol-finetuning-main/README.md: -------------------------------------------------------------------------------- 1 | # pol-finetuning 2 | 3 | ## Setup 4 | - Create `pol_models` directory 5 | - Download Model4 (Models4 folder in GDrive) and pol-roberta weights from the Google Drive to the following paths in `pol_models` 6 | ``` 7 | pol_models 8 | Model4 9 | config.json 10 | pytorch_model.bin 11 | tokenizer_config.json 12 | vocab.txt 13 | ``` 14 | - Conda env file: `environment.yml` 15 | - Originally run on GCP instance with 16 | 4 x NVIDIA Tesla A100, adjust batch size according to the machine you're running on 17 | 18 | 19 | -------------------------------------------------------------------------------- /pretraining/pol-finetuning-main/environment.yml: -------------------------------------------------------------------------------- 1 | name: pol 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1 8 | - _pytorch_select=0.1=cpu_0 9 | - absl-py=1.0.0=pyhd8ed1ab_0 10 | - blas=1.0=mkl 11 | - c-ares==1.18.1 12 | - ca-certificates=2021.10.8 13 | - cudatoolkit 14 | - freetype=2.10.4 15 | - giflib=5.2.1 16 | - grpcio==1.46.1 17 | - importlib-metadata==4.11.3 18 | - intel-openmp 19 | - jbig==2.1 20 | - jpeg==9e 21 | - lcms2=2.12 22 | - ld_impl_linux-64==2.36.1 23 
| - lerc=3.0 24 | - libblas=3.9.0 25 | - libcblas=3.9.0 26 | - libdeflate=1.10 27 | - libffi=3.4.2 28 | - libgfortran-ng 29 | - libgfortran5 30 | - liblapack==3.9.0 31 | - libopenblas=0.3.20 32 | - libpng=1.6.37 33 | - libprotobuf==3.20.1 34 | - libtiff==4.3.0 35 | - libuv==1.43.0 36 | - libwebp==1.2.2 37 | - libwebp-base==1.2.2 38 | - libxcb=1.13 39 | - libzlib=1.2.11 40 | - lz4-c=1.9.3 41 | - markdown=3.3.7=pyhd8ed1ab_0 42 | - mkl 43 | - ncurses=6.3 44 | - ninja=1.10.2 45 | - numpy=1.21.6 46 | - openjpeg==2.4.0 47 | - openssl=3.0.3 48 | - pillow==9.1.0 49 | - pip=22.1 50 | - pthread-stubs==0.4 51 | - python=3.7.12 52 | - python_abi=3.7=2_cp37m 53 | - pytorch=1.7.1 54 | - readline=8.1 55 | - setuptools==62.2.0 56 | - six=1.16.0=pyh6c4a22f_0 57 | - sqlite 58 | - tensorboard=1.15.0=py37_0 59 | - tk==8.6.12 60 | - torchaudio=0.7.2=py37 61 | - torchvision=0.8.2 62 | - typing_extensions=4.2.0=pyha770c72_1 63 | - werkzeug=2.1.2=pyhd8ed1ab_1 64 | - wheel=0.37.1=pyhd8ed1ab_0 65 | - xorg-libxau==1.0.9 66 | - xorg-libxdmcp=1.1.3 67 | - xz==5.2.5 68 | - zipp=3.8.0=pyhd8ed1ab_0 69 | - zlib==1.2.11 70 | - zstd==1.5.2 71 | - pip: 72 | - aiohttp==3.8.1 73 | - aiosignal==1.2.0 74 | - apex==0.1 75 | - async-timeout==4.0.2 76 | - asynctest==0.13.0 77 | - attrs==21.4.0 78 | - certifi==2021.10.8 79 | - charset-normalizer==2.0.12 80 | - click==8.1.3 81 | - datasets==2.1.0 82 | - dill==0.3.4 83 | - docker-pycreds==0.4.0 84 | - filelock==3.6.0 85 | - frozenlist==1.3.0 86 | - fsspec==2022.3.0 87 | - gitdb==4.0.9 88 | - gitpython==3.1.27 89 | - huggingface-hub==0.5.1 90 | - idna==3.3 91 | - joblib==1.1.0 92 | - multidict==6.0.2 93 | - multiprocess==0.70.12.2 94 | - nltk==3.7 95 | - packaging==21.3 96 | - pandas==1.3.5 97 | - pathtools==0.1.2 98 | - promise==2.3 99 | - protobuf==3.20.1 100 | - psutil==5.9.0 101 | - pyarrow==7.0.0 102 | - pyparsing==3.0.8 103 | - python-dateutil==2.8.2 104 | - pytz==2022.1 105 | - pyyaml==6.0 106 | - regex==2022.4.24 107 | - requests==2.27.1 108 | - responses==0.18.0 109 | - sacremoses==0.0.49 110 | - scikit-learn==1.0.2 111 | - scipy==1.7.3 112 | - sentry-sdk==1.5.10 113 | - setproctitle==1.2.3 114 | - shortuuid==1.0.8 115 | - smmap==5.0.0 116 | - threadpoolctl==3.1.0 117 | - tokenizers==0.12.1 118 | - tqdm==4.64.0 119 | - transformers==4.18.0 120 | - urllib3==1.26.9 121 | - wandb==0.12.15 122 | - xxhash==3.0.0 123 | - yarl==1.7.2 124 | - libgcc-ng 125 | - libgomp 126 | - libnsl 127 | - libstdcxx-ng 128 | prefix: /opt/conda/envs/pol 129 | -------------------------------------------------------------------------------- /pretraining/pol-finetuning-main/scripts/run_casehold.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT=pol 2 | TASK='case_hold' 3 | SEED=1 4 | 5 | 6 | MODEL='roberta-base' 7 | MODEL_NAME_OR_PATH=${MODEL} 8 | 9 | 10 | 11 | python tasks/casehold.py \ 12 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 13 | --task_name ${TASK} \ 14 | --do_train \ 15 | --do_eval \ 16 | --evaluation_strategy='steps' \ 17 | --eval_steps 500 \ 18 | --save_strategy='steps' \ 19 | --save_steps 0 \ 20 | --logging_steps 500 \ 21 | --max_seq_length 512 \ 22 | --per_device_train_batch_size=2 \ 23 | --per_device_eval_batch_size=64 \ 24 | --learning_rate=1e-5 \ 25 | --num_train_epochs=7 \ 26 | --seed ${SEED} \ 27 | --fp16 \ 28 | --report_to=wandb \ 29 | --run_name=${TASK}/${MODEL}/seed_${SEED} \ 30 | --output_dir=logs/${TASK}/${MODEL}/lr=1e-5/seed_${SEED} \ 31 | --overwrite_output_dir 
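# The setup README stages the Pile of Law checkpoint under pol_models/Model4
# (config.json, pytorch_model.bin, tokenizer files). To fine-tune that checkpoint
# instead of the roberta-base baseline, one option, assuming the directory layout
# described in the README, is to point MODEL at the local path before running:
#
#   MODEL='pol_models/Model4'
#   MODEL_NAME_OR_PATH=${MODEL}
#
# Per the README, adjust --per_device_train_batch_size if not running on 4 x A100.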
-------------------------------------------------------------------------------- /pretraining/pol-finetuning-main/tasks/adam_bias_correct.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable, Iterable, Tuple 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from transformers.optimization import AdamW 8 | 9 | class AdamWBC(AdamW): 10 | """ 11 | Implements Adam algorithm with bias correction for BERT, helps with training stability for large models 12 | Parameters: 13 | params (`Iterable[nn.parameter.Parameter]`): 14 | Iterable of parameters to optimize or dictionaries defining parameter groups. 15 | lr (`float`, *optional*, defaults to 1e-3): 16 | The learning rate to use. 17 | betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)): 18 | Adam's betas parameters (b1, b2). 19 | eps (`float`, *optional*, defaults to 1e-6): 20 | Adam's epsilon for numerical stability. 21 | weight_decay (`float`, *optional*, defaults to 0): 22 | Decoupled weight decay to apply. 23 | correct_bias (`bool`, *optional*, defaults to `True`): 24 | Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). 25 | no_deprecation_warning (`bool`, *optional*, defaults to `False`): 26 | A flag used to disable the deprecation warning (set to `True` to disable the warning). 27 | """ 28 | 29 | def __init__( 30 | self, 31 | params: Iterable[nn.parameter.Parameter], 32 | lr: float = 1e-3, 33 | betas: Tuple[float, float] = (0.9, 0.999), 34 | eps: float = 1e-6, 35 | weight_decay: float = 0.0, 36 | correct_bias: bool = True, 37 | no_deprecation_warning: bool = False, 38 | ): 39 | super().__init__(params, lr, betas, eps, weight_decay, correct_bias, no_deprecation_warning) 40 | 41 | def step(self, closure: Callable = None): 42 | """ 43 | Performs a single optimization step. 44 | Arguments: 45 | closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. 
46 | """ 47 | loss = None 48 | if closure is not None: 49 | loss = closure() 50 | 51 | for group in self.param_groups: 52 | for p in group["params"]: 53 | if p.grad is None: 54 | continue 55 | grad = p.grad.data 56 | if grad.is_sparse: 57 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 58 | 59 | state = self.state[p] 60 | 61 | # State initialization 62 | if len(state) == 0: 63 | state["step"] = 0 64 | # Exponential moving average of gradient values 65 | state["exp_avg"] = torch.zeros_like(p.data) 66 | # Exponential moving average of squared gradient values 67 | state["exp_avg_sq"] = torch.zeros_like(p.data) 68 | 69 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 70 | beta1, beta2 = group["betas"] 71 | 72 | state["step"] += 1 73 | 74 | # Decay the first and second moment running average coefficient 75 | # In-place operations to update the averages at the same time 76 | exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) 77 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 78 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 79 | 80 | step_size = group["lr"] 81 | # Bias correction for Bert 82 | bias_correction1 = 1.0 - beta1 ** state["step"] 83 | bias_correction2 = 1.0 - beta2 ** state["step"] 84 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 85 | 86 | p.data.addcdiv_(exp_avg, denom, value=-step_size) 87 | 88 | # Just adding the square of the weights to the loss function is *not* 89 | # the correct way of using L2 regularization/weight decay with Adam, 90 | # since that will interact with the m and v parameters in strange ways. 91 | # 92 | # Instead we want to decay the weights in a manner that doesn't interact 93 | # with the m/v parameters. This is equivalent to adding the square 94 | # of the weights to the loss with plain (non-momentum) SGD. 95 | # Add weight decay at the end (fixed version) 96 | if group["weight_decay"] > 0.0: 97 | p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"])) 98 | 99 | return loss -------------------------------------------------------------------------------- /pretraining/pol-finetuning-main/tasks/casehold_helpers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from dataclasses import dataclass 4 | from enum import Enum 5 | from typing import List, Optional 6 | 7 | import tqdm 8 | import re 9 | 10 | from filelock import FileLock 11 | from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available 12 | import datasets 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | @dataclass(frozen=True) 18 | class InputFeatures: 19 | """ 20 | A single set of features of data. 21 | Property names are the same names as the corresponding inputs to a model. 
22 | """ 23 | 24 | input_ids: List[List[int]] 25 | attention_mask: Optional[List[List[int]]] 26 | token_type_ids: Optional[List[List[int]]] 27 | label: Optional[int] 28 | 29 | 30 | class Split(Enum): 31 | train = "train" 32 | dev = "dev" 33 | test = "test" 34 | 35 | 36 | if is_torch_available(): 37 | import torch 38 | from torch.utils.data.dataset import Dataset 39 | 40 | class MultipleChoiceDataset(Dataset): 41 | """ 42 | PyTorch multiple choice dataset class 43 | """ 44 | 45 | features: List[InputFeatures] 46 | 47 | def __init__( 48 | self, 49 | tokenizer: PreTrainedTokenizer, 50 | task: str, 51 | max_seq_length: Optional[int] = None, 52 | overwrite_cache=False, 53 | mode: Split = Split.train, 54 | ): 55 | dataset = datasets.load_dataset('lex_glue', task) 56 | tokenizer_name = re.sub('[^a-z]+', ' ', tokenizer.name_or_path).title().replace(' ', '') 57 | cached_features_file = os.path.join( 58 | '.cache', 59 | task, 60 | "cached_{}_{}_{}_{}".format( 61 | mode.value, 62 | tokenizer_name, 63 | str(max_seq_length), 64 | task, 65 | ), 66 | ) 67 | 68 | # Make sure only the first process in distributed training processes the dataset, 69 | # and the others will use the cache. 70 | lock_path = cached_features_file + ".lock" 71 | if not os.path.exists(os.path.join('.cache', task)): 72 | if not os.path.exists('.cache'): 73 | os.mkdir('.cache') 74 | os.mkdir(os.path.join('.cache', task)) 75 | with FileLock(lock_path): 76 | 77 | if os.path.exists(cached_features_file) and not overwrite_cache: 78 | logger.info(f"Loading features from cached file {cached_features_file}") 79 | self.features = torch.load(cached_features_file) 80 | else: 81 | logger.info(f"Creating features from dataset file at {task}") 82 | if mode == Split.dev: 83 | examples = dataset['validation'] 84 | elif mode == Split.test: 85 | examples = dataset['test'] 86 | elif mode == Split.train: 87 | examples = dataset['train'] 88 | logger.info("Training examples: %s", len(examples)) 89 | self.features = convert_examples_to_features( 90 | examples, 91 | max_seq_length, 92 | tokenizer, 93 | ) 94 | logger.info("Saving features into cached file %s", cached_features_file) 95 | torch.save(self.features, cached_features_file) 96 | 97 | def __len__(self): 98 | return len(self.features) 99 | 100 | def __getitem__(self, i) -> InputFeatures: 101 | return self.features[i] 102 | 103 | 104 | if is_tf_available(): 105 | import tensorflow as tf 106 | 107 | class TFMultipleChoiceDataset: 108 | """ 109 | TensorFlow multiple choice dataset class 110 | """ 111 | 112 | features: List[InputFeatures] 113 | 114 | def __init__( 115 | self, 116 | tokenizer: PreTrainedTokenizer, 117 | task: str, 118 | max_seq_length: Optional[int] = 256, 119 | overwrite_cache=False, 120 | mode: Split = Split.train, 121 | ): 122 | dataset = datasets.load_dataset('lex_glue') 123 | 124 | logger.info(f"Creating features from dataset file at {task}") 125 | if mode == Split.dev: 126 | examples = dataset['validation'] 127 | elif mode == Split.test: 128 | examples = dataset['test'] 129 | else: 130 | examples = dataset['train'] 131 | logger.info("Training examples: %s", len(examples)) 132 | 133 | self.features = convert_examples_to_features( 134 | examples, 135 | max_seq_length, 136 | tokenizer, 137 | ) 138 | 139 | def gen(): 140 | for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"): 141 | if ex_index % 10000 == 0: 142 | logger.info("Writing example %d of %d" % (ex_index, len(examples))) 143 | 144 | yield ( 145 | { 146 | "input_ids": ex.input_ids, 147 
| "attention_mask": ex.attention_mask, 148 | "token_type_ids": ex.token_type_ids, 149 | }, 150 | ex.label, 151 | ) 152 | 153 | self.dataset = tf.data.Dataset.from_generator( 154 | gen, 155 | ( 156 | { 157 | "input_ids": tf.int32, 158 | "attention_mask": tf.int32, 159 | "token_type_ids": tf.int32, 160 | }, 161 | tf.int64, 162 | ), 163 | ( 164 | { 165 | "input_ids": tf.TensorShape([None, None]), 166 | "attention_mask": tf.TensorShape([None, None]), 167 | "token_type_ids": tf.TensorShape([None, None]), 168 | }, 169 | tf.TensorShape([]), 170 | ), 171 | ) 172 | 173 | def get_dataset(self): 174 | self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features))) 175 | 176 | return self.dataset 177 | 178 | def __len__(self): 179 | return len(self.features) 180 | 181 | def __getitem__(self, i) -> InputFeatures: 182 | return self.features[i] 183 | 184 | 185 | def convert_examples_to_features( 186 | examples: datasets.Dataset, 187 | max_length: int, 188 | tokenizer: PreTrainedTokenizer, 189 | ) -> List[InputFeatures]: 190 | """ 191 | Loads a data file into a list of `InputFeatures` 192 | """ 193 | features = [] 194 | for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"): 195 | if ex_index % 10000 == 0: 196 | logger.info("Writing example %d of %d" % (ex_index, len(examples))) 197 | choices_inputs = [] 198 | for ending_idx, ending in enumerate(example['endings']): 199 | context = example['context'] 200 | inputs = tokenizer( 201 | context, 202 | ending, 203 | add_special_tokens=True, 204 | max_length=max_length, 205 | padding="max_length", 206 | truncation=True, 207 | ) 208 | 209 | choices_inputs.append(inputs) 210 | 211 | label = example['label'] 212 | 213 | input_ids = [x["input_ids"] for x in choices_inputs] 214 | attention_mask = ( 215 | [x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None 216 | ) 217 | token_type_ids = ( 218 | [x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None 219 | ) 220 | 221 | features.append( 222 | InputFeatures( 223 | input_ids=input_ids, 224 | attention_mask=attention_mask, 225 | token_type_ids=token_type_ids, 226 | label=label, 227 | ) 228 | ) 229 | 230 | for f in features[:2]: 231 | logger.info("*** Example ***") 232 | logger.info("feature: %s" % f) 233 | 234 | return features 235 | -------------------------------------------------------------------------------- /privacy/README.md: -------------------------------------------------------------------------------- 1 | # Privacy Experiments 2 | 3 | ## EOIR 4 | 5 | The EOIR experiments and links to data/models is available in the eoir folder. 6 | 7 | 8 | ## Jane Doe 9 | 10 | The Jane Doe experiments can be found in the Jane Doe Folder. 11 | 12 | 13 | -------------------------------------------------------------------------------- /privacy/eoir/README.md: -------------------------------------------------------------------------------- 1 | 2 | # EOIR Pseudonyms Experiment 3 | 4 | In this experiment we first create a dataset of paragraphs from: 5 | 6 | ``` 7 | python create_pseodonyms_dataset.py 8 | ``` 9 | 10 | We upload the generated dataset linking data to labels in [this HuggingFace Repo](https://huggingface.co/datasets/pile-of-law/eoir_privacy). 
11 | 12 | Then we use a Colab notebook to train a distillbert model, also available as an ipython notebook: 13 | 14 | ``` 15 | EOIR.ipynb 16 | ``` 17 | 18 | The resulting model is a runable model which we [also upload to HF](https://huggingface.co/pile-of-law/distilbert-base-uncased-finetuned-eoir_privacy). 19 | 20 | Then we run a perturbation experiment via the script as seen in 21 | 22 | ``` 23 | EOIR_validation_exp.ipynb 24 | ``` 25 | 26 | For the causal lexicon experiment we use: 27 | 28 | ``` 29 | causal_exp.py 30 | ``` 31 | 32 | -------------------------------------------------------------------------------- /privacy/eoir/causal_exp.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datasets import load_dataset 3 | import spacy 4 | from citeurl import Citator 5 | citator = Citator() 6 | 7 | from collections import Counter 8 | def is_number(s): 9 | try: 10 | float(s) 11 | return True 12 | except ValueError: 13 | return False 14 | 15 | def build_vocab(texts, max_vocab=10000, min_freq=3): 16 | nlp = spacy.load("en_core_web_sm") # just the tokenizer 17 | wc = Counter() 18 | for doc in nlp.pipe(texts): 19 | citations = citator.list_cites(doc.text) 20 | for x in citations: 21 | wc[str(x).lower()] += 1 22 | for word in doc: 23 | if word.ent_type_ == "CARDINAL": 24 | continue 25 | if is_number(word.text): 26 | continue 27 | if len(word.text.strip()) <= 2: 28 | continue 29 | wc[word.lower_] += 1 30 | 31 | word2id = {} 32 | id2word = {} 33 | for word, count in wc.most_common(): 34 | if count < min_freq: break 35 | if len(word2id) >= max_vocab: break 36 | wid = len(word2id) 37 | word2id[word] = wid 38 | id2word[wid] = word 39 | return list([x for x in word2id.keys()]) 40 | 41 | process = True 42 | if process: 43 | dataset = load_dataset("pile-of-law/eoir_privacy", "all", split="all", use_auth_token=True) 44 | 45 | data = {} 46 | 47 | dataset.to_csv("descriptions.csv") 48 | 49 | import pandas as pd 50 | x = pd.read_csv("descriptions.csv") 51 | pseudo_label = { 0 : "no_pseudo", 1 : "pseudo"} 52 | x["label"] = [pseudo_label[z] for z in x["label"]] 53 | x = x.drop_duplicates("text") 54 | pres = pd.read_csv("presidents.csv") 55 | 56 | year_to_pres_map = {} 57 | for name, years in zip(pres["President Name"], pres["Years In Office"]): 58 | splitted_years = years.split("-") 59 | if splitted_years[-1].strip() == "": 60 | splitted_years = [splitted_years[0]] 61 | if len(splitted_years) > 1: 62 | for i in range(int(splitted_years[0]), int(splitted_years[1])+1): 63 | year_to_pres_map[i] = name 64 | else: 65 | year_to_pres_map[int(splitted_years[0])] = name 66 | 67 | x["president"] = [year_to_pres_map[int(year)] if is_number(year) and int(year) in year_to_pres_map.keys() else "UNK" for year in x["year"]] 68 | x = x[[pres != "UNK" for pres in x["president"]]] 69 | x = x[[True if is_number(year) else False for year in x["year"]]] 70 | 71 | x.to_csv("descriptions.csv") 72 | 73 | vocab = build_vocab(dataset["text"]) 74 | with open("vocab.txt", "w") as f: 75 | for v in vocab: 76 | f.write(v + "\n") 77 | 78 | with open("vocab.txt", "r") as f: 79 | vocab = [x.strip() for x in f.readlines()] 80 | 81 | print("Finished vocab!") 82 | import causal_attribution 83 | merged_scores = {} 84 | for method in ["residualization"]: 85 | for hidden_size in [512]: 86 | for lr in [7e-4]: 87 | importance_scores = causal_attribution.score_vocab( 88 | vocab=vocab, 89 | scoring_model=method, 90 | hidden_size = hidden_size, 91 | lr=lr, 92 | train_steps = 3000, 93 | 
max_seq_len=750, 94 | status_bar=True, 95 | use_gpu=True, 96 | csv="descriptions.csv", 97 | delimiter=",", 98 | name_to_type={ 99 | 'text': 'input', 100 | #'name' : 'control', 101 | #'president' : 'control', 102 | 'year' : 'control', 103 | 'label': 'predict' 104 | }) 105 | for (key, value) in importance_scores["label"]["pseudo"]: 106 | if key not in merged_scores: 107 | merged_scores[key] = [] 108 | merged_scores[key].append(value) 109 | 110 | averaged = {} 111 | import numpy as np 112 | for key, val in merged_scores.items(): 113 | averaged[key] = np.mean(val) 114 | 115 | with open("averaged_results.json", "w") as f: 116 | f.write(json.dumps(averaged)) 117 | import pdb; pdb.set_trace() 118 | -------------------------------------------------------------------------------- /privacy/eoir/create_pseudonyms_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import syntok.segmenter as segmenter 3 | import re 4 | import random 5 | try: 6 | import lzma as xz 7 | except ImportError: 8 | import pylzma as xz 9 | 10 | with open('../../dataset_creation/eoir/cache/train.eoir_pseudonym.jsonl', 'r') as f: 11 | data = [json.loads(x) for x in f.readlines()] 12 | 13 | import spacy 14 | from collections import Counter 15 | 16 | import re 17 | nlp = spacy.load("en_core_web_sm") 18 | replacement = re.compile(r"(the )?(respondent|applicant|defendant|plaintiff|petitioner)(s)?", re.IGNORECASE) 19 | def paragraphs(document): 20 | start = 0 21 | for token in document: 22 | if token.is_space and token.text.count("\n") > 1: 23 | yield document[start:token.i] 24 | start = token.i 25 | yield document[start:] 26 | 27 | with xz.open("./train.privacy.eoir.jsonl.xz", "wt") as f1: 28 | with xz.open("./validation.privacy.eoir.jsonl.xz", "wt") as f2: 29 | for datapoint in data: 30 | doc = nlp(datapoint["text"]) 31 | for paragraph in paragraphs(doc): 32 | if len(paragraph.text) < 700: 33 | continue 34 | if ", Respondent" in paragraph.text: 35 | continue 36 | # TODO: better paragraph splitting 37 | text = paragraph.text 38 | text = text.replace("\n", " ") 39 | 40 | for ent in paragraph.ents: 41 | if ent.label_ == "PERSON": 42 | if not ("v." in text[max(0, ent.start-5):min(ent.end + 5, len(text))].lower()) or ("in re" in text[max(0, ent.start-5):min(ent.end + 5, len(text))].lower()): 43 | swapped_para = replacement.sub("[MASK]", text) 44 | 45 | if "[MASK]" not in swapped_para: 46 | continue 47 | 48 | print(datapoint["name"]) 49 | new_data = { 50 | "text" : swapped_para, 51 | "label" : datapoint["is_pseudonym"], 52 | "year" : datapoint["created_timestamp"].split(".")[-1], 53 | "name" : datapoint["name"] 54 | } 55 | 56 | if random.random() < .15: 57 | f2.write(json.dumps(new_data) + "\n") 58 | else: 59 | f1.write(json.dumps(new_data) + "\n") -------------------------------------------------------------------------------- /privacy/janedoe/README.md: -------------------------------------------------------------------------------- 1 | # Jane Doe Privacy Experiment 2 | 3 | To run the positive example experiment run jane_doe.py. Data is linked in the associated code. 4 | 5 | To run the negative example experiment (where Jane Doe pseudonym is not needed), run jane_doe_negative.py 6 | 7 | Note: you have to change the model name inline to switch between bert and pol-bert. 
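Concretely, the line to edit in jane_doe.py / jane_doe_negative.py looks like the sketch below; the `bert-large-uncased` baseline name is inferred from the result filenames read by jane_doe_plot.py:

```
model_name = 'pile-of-law/legalbert-large-1.7M-2'  # pol-bert
# model_name = 'bert-large-uncased'                # bert baseline
```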
8 | 9 | 10 | -------------------------------------------------------------------------------- /privacy/janedoe/jane_doe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import mxnet as mx 4 | import re 5 | import names 6 | import numpy as np 7 | import json 8 | from mlm.scorers import MLMScorer, MLMScorerPT, LMScorer 9 | from mlm.models import get_pretrained 10 | 11 | import numpy as np 12 | 13 | sheet_id = "1zf2xfYJ0dvmSVFHUATvDXqG8TFX3Gmt1CppiWTyscg0" 14 | sheet_name = "Sheet1" 15 | url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/gviz/tq?tqx=out:csv&sheet={sheet_name}" 16 | 17 | df = pd.read_csv(url, on_bad_lines='skip') 18 | 19 | df["Text"].dropna() 20 | 21 | 22 | 23 | ctxs = [mx.gpu(0)] # or, e.g., [mx.gpu(0), mx.gpu(1)] 24 | model_name = 'pile-of-law/legalbert-large-1.7M-2' 25 | model, vocab, tokenizer = get_pretrained(ctxs, model_name) 26 | scorer = MLMScorerPT(model, vocab, tokenizer, ctxs) 27 | 28 | insensitive_hippo = re.compile(re.escape('jane doe'), re.IGNORECASE) 29 | insensitive_hippo2 = re.compile(re.escape(' doe '), re.IGNORECASE) 30 | insensitive_hippo3 = re.compile(re.escape('jane roe'), re.IGNORECASE) 31 | insensitive_hippo4 = re.compile(re.escape(' roe '), re.IGNORECASE) 32 | 33 | average_diff = [] 34 | 35 | for sent in list(df["Text"].dropna()): 36 | replacement_sentences = [sent] 37 | for i in range(2): 38 | new_name = names.get_full_name(gender="female") 39 | new_sent = insensitive_hippo.sub(new_name, sent) 40 | new_sent = insensitive_hippo2.sub(" " + new_name + " ", new_sent) 41 | new_sent = insensitive_hippo3.sub(new_name, new_sent) 42 | new_sent = insensitive_hippo4.sub(" " + new_name + " ", new_sent) 43 | replacement_sentences.append(new_sent) 44 | 45 | torch.cuda.empty_cache() 46 | with torch.no_grad(): 47 | scores = scorer.score_sentences(replacement_sentences, split_size=51) 48 | main_score = scores[0] 49 | for i, x in enumerate(scores): 50 | if i == 0: continue 51 | diff = main_score - x 52 | average_diff.append(diff) 53 | print(np.mean(average_diff)) 54 | 55 | with open(f"jane_doe_diffs_{model_name.split('/')[-1]}.json", "w") as f: 56 | f.write(json.dumps(average_diff)) 57 | 58 | -------------------------------------------------------------------------------- /privacy/janedoe/jane_doe_negative.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import torch 4 | import pandas as pd 5 | from mlm.scorers import MLMScorer, MLMScorerPT, LMScorer 6 | from mlm.models import get_pretrained 7 | import pickle 8 | import numpy as np 9 | 10 | # To generate pickle, look to https://github.com/microsoft/biosbias 11 | with open("BIOS.pkl", "rb") as f: 12 | sentences = pickle.load(f)[:500] 13 | 14 | import mxnet as mx 15 | import re 16 | import names 17 | import numpy as np 18 | 19 | ctxs = [mx.gpu(0)] # or, e.g., [mx.gpu(0), mx.gpu(1)] 20 | model_name = 'pile-of-law/legalbert-large-1.7M-2' # Swap model name to use different model 21 | model, vocab, tokenizer = get_pretrained(ctxs, model_name) 22 | scorer = MLMScorerPT(model, vocab, tokenizer, ctxs) 23 | 24 | import re 25 | 26 | 27 | import names 28 | average_diff = [] 29 | 30 | import random 31 | for sent in list(sentences): 32 | sent = sent["bio"] 33 | 34 | replacement_sentences = [sent.replace("_", "Jane Doe")] 35 | 36 | for i in range(2): 37 | new_name = names.get_full_name(gender="female") 38 | new_sent = sent.replace("_", new_name) 39 | 
replacement_sentences.append(new_sent) 40 | 41 | torch.cuda.empty_cache() 42 | with torch.no_grad(): 43 | scores = scorer.score_sentences(replacement_sentences, split_size=51) 44 | main_score = scores[0] 45 | for i, x in enumerate(scores): 46 | if i == 0: continue 47 | diff = main_score - x 48 | average_diff.append(diff) 49 | print(np.mean(average_diff)) 50 | 51 | with open(f"jane_doe_diffs_{model_name.split('/')[-1]}_negative.json", "w") as f: 52 | f.write(json.dumps(average_diff)) 53 | 54 | -------------------------------------------------------------------------------- /privacy/janedoe/jane_doe_plot.py: -------------------------------------------------------------------------------- 1 | import json 2 | import seaborn as sns 3 | import pandas as pd 4 | 5 | import matplotlib.pyplot as plt 6 | 7 | # this is the setting for plots for research paper and articles. 8 | # %matplotlib inline 9 | # %config InlineBackend.figure_format = 'retina' 10 | 11 | # Set the global font to be DejaVu Sans, size 10 (or any other sans-serif font of your choice!) 12 | plt.rc('font',**{'family':'sans-serif','sans-serif':['DejaVu Sans'],'size':16}) 13 | 14 | # Set the font used for MathJax - more on this later 15 | plt.rc('mathtext',**{'default':'regular'}) 16 | 17 | # Set the style for seaborn 18 | plt.style.use(['seaborn-whitegrid', 'seaborn-paper']) 19 | 20 | import matplotlib.pylab as pylab 21 | params = {'legend.fontsize': 'large', 22 | 'axes.labelsize': 'large', 23 | 'axes.titlesize': 'large', 24 | 'xtick.labelsize': 'large', 25 | 'ytick.labelsize': 'large' 26 | } 27 | 28 | pylab.rcParams.update(**params) 29 | 30 | import seaborn as sns 31 | sns.set_context(rc=params) 32 | 33 | def stylize_axes(ax, title): 34 | """ 35 | Stylize the axes by removing ths spines and ticks. 36 | """ 37 | # removes the top and right lines from the plot rectangle 38 | ax.spines['top'].set_visible(False) 39 | ax.spines['right'].set_visible(False) 40 | 41 | ax.xaxis.set_tick_params(top=False, direction='out', width=1) 42 | ax.yaxis.set_tick_params(right=False, direction='out', width=1) 43 | 44 | # Enforce the size of the title, label and tick labels 45 | ax.set_xlabel(ax.get_xlabel(), fontsize='large') 46 | ax.set_ylabel(ax.get_ylabel(), fontsize='large') 47 | 48 | ax.set_yticklabels(ax.get_yticklabels(), fontsize='medium') 49 | ax.set_xticklabels(ax.get_xticklabels(), fontsize='medium') 50 | 51 | ax.set_title(title, fontsize='large') 52 | 53 | def save_image(fig, title): 54 | """ 55 | Save the figure as PNG and pdf files 56 | """ 57 | if title is not None: 58 | fig.savefig(title+".png", dpi=300, bbox_inches='tight', transparent=True) 59 | fig.savefig(title+".pdf", bbox_inches='tight') 60 | 61 | def figure_size(fig, size): 62 | fig.set_size_inches(size) 63 | fig.tight_layout() 64 | 65 | def resadjust(ax, xres=None, yres=None): 66 | """ 67 | Send in an axis and fix the resolution as desired. 
68 | """ 69 | 70 | if xres: 71 | start, stop = ax.get_xlim() 72 | ticks = np.arange(start, stop + xres, xres) 73 | ax.set_xticks(ticks) 74 | if yres: 75 | start, stop = ax.get_ylim() 76 | ticks = np.arange(start, stop + yres, yres) 77 | ax.set_yticks(ticks) 78 | 79 | with open("jane_doe_diffs_legalbert-large-1.7M-2.json", "r") as f: 80 | jane_doe_lb = json.load(f) 81 | with open("jane_doe_diffs_bert-large-uncased.json", "r") as f: 82 | jane_doe_b = json.load(f) 83 | with open("jane_doe_diffs_legalbert-large-1.7M-2_negative.json", "r") as f: 84 | jane_doe_lb_negjd = json.load(f) 85 | with open("jane_doe_diffs_bert-large-uncased_negative.json", "r") as f: 86 | jane_doe_b_negjd = json.load(f) 87 | 88 | 89 | import matplotlib.pyplot as plt 90 | 91 | df = pd.DataFrame.from_dict({ 92 | "Model" : (["bert"] * len(jane_doe_b)) + (["pol-bert"] * len(jane_doe_lb)) + (["bert"] * len(jane_doe_b_negjd)) + (["pol-bert"] * len(jane_doe_lb_negjd)), 93 | "Jane Doe Score" : jane_doe_b + jane_doe_lb + jane_doe_b_negjd + jane_doe_lb_negjd, 94 | "Sample Type" : ["Court Case\n(Jane Doe)"] * (len(jane_doe_b) + len(jane_doe_lb)) + ["Bios"] * (len(jane_doe_b_negjd) + len(jane_doe_lb_negjd)) 95 | }) 96 | 97 | #swarm_plot = sns.barplot(x="Model", y="Jane Doe Score", data=df) 98 | swarm_plot = sns.factorplot(x = 'Sample Type', y='Jane Doe Score', 99 | hue = 'Model', data=df, kind='bar') 100 | #fig = swarm_plot.get_figure() 101 | plt.xticks(rotation=20) 102 | 103 | save_image(swarm_plot.fig, "janedoe") 104 | import numpy as np 105 | print(np.mean(jane_doe_b) - np.mean(jane_doe_b_negjd)) 106 | print(np.mean(jane_doe_lb) - np.mean(jane_doe_lb_negjd)) -------------------------------------------------------------------------------- /scrub/README.md: -------------------------------------------------------------------------------- 1 | ## Scrub 2 | This directory contains the script, `scrub.py`, for scrubbing the dataset for sensitive information. `scrub.py` writes the scrubbed data to jsonl.xz files in the `data_scrubbed` directory and the filth data to json files in the `filth` directory. 
3 | 4 | The filth data json files contain File Filth JSON objects with the following structure: 5 | ``` 6 | { 7 | filename: name of the data file (dtype: string) 8 | filth_count: total count of filth detected across all documents in the data file (dtype: integer) 9 | word_count: total count of all words across all documents in the data file (dtype: integer) 10 | filth_doc_count: total count of documents with >0 filth in the data file (dtype: integer) 11 | doc_count: total count of documents in the data file (dtype: integer) 12 | filth_data: list of Document Filth JSON objects, corresponding to each document in the data file, order of Document Filth JSON objects perserves document order from original data file (dtype: list[Document Filth JSON object]) 13 | } 14 | ``` 15 | 16 | where a Document Filth JSON object has the following structure: 17 | ``` 18 | { 19 | url: source url where document was scraped (dtype: string) 20 | filth_count: count of filth detected in document (dtype: integer) 21 | word_count: count of all words in document (dtype: integer) 22 | filth: list of Filth JSON objects, corresponding to each item of filth detected in the document (dtype: list[Filth JSON object]) 23 | } 24 | ``` 25 | 26 | and a Filth JSON object has the following structure: 27 | ``` 28 | { 29 | type: type of filth, corresponds to type field in [scrubadub.filth.Filth](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_filth.html) class (dtype: string or list[string]) 30 | text: filth text / what was scrubbed out, corresponds to text field in [scrubadub.filth.Filth](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_filth.html) class (dtype: string or list[string]) 31 | merged: if is MergedFilth, filth is merged when multiple types of filth are detected for overlapping text, if merged = true, type and text will be lists (that share the same order) of all of the types and corresponding texts for which filth was detected (dtype: boolean) 32 | } 33 | ``` -------------------------------------------------------------------------------- /scrub/pii.py: -------------------------------------------------------------------------------- 1 | # Detects PII 2 | import glob 3 | import os 4 | import json 5 | import numpy as np 6 | import math 7 | import argparse 8 | from presidio_analyzer import AnalyzerEngine 9 | from presidio_anonymizer import AnonymizerEngine 10 | from tqdm import tqdm 11 | from datasets import load_dataset 12 | from multiprocessing import Pool, cpu_count 13 | 14 | # DATA DIRS 15 | DATA_URL = 'pile-of-law/pile-of-law' 16 | LOG_DIR = 'logs/pii' 17 | 18 | DATASETS = [ 19 | 'bar_exam_outlines', 20 | 'fre', 21 | 'tos', 22 | 'ftc_advisory_opinions', 23 | 'frcp', 24 | 'constitutions', 25 | 'cfpb_creditcard_contracts', 26 | 'uscode', 27 | 'olc_memos', 28 | 'cc_casebooks', 29 | 'echr', 30 | 'federal_register', 31 | 'un_debates', 32 | 'euro_parl', 33 | 'scotus_oral_arguments', 34 | 'cfr', 35 | 'founding_docs', 36 | 'tax_rulings', 37 | 'r_legaladvice', 38 | 'eurlex', 39 | 'us_bills', 40 | 'nlrb_decisions', 41 | 'canadian_decisions', 42 | 'eoir', 43 | 'dol_ecab', 44 | # > 200MB 45 | 'oig', 46 | 'scotus_filings', 47 | 'state_codes', 48 | # > 1GB 49 | 'congressional_hearings', 50 | 'edgar', 51 | 'bva_opinions', 52 | 'courtlistener_docket_entry_documents', 53 | 'atticus_contracts', 54 | 'courtlistener_opinions', 55 | ] 56 | 57 | SPLITS = ['train', 'validation'] 58 | 59 | # PII entities 60 | ENTITIES = ["US_BANK_NUMBER", "US_DRIVER_LICENSE", "US_ITIN", "US_PASSPORT", "US_SSN", "CREDIT_CARD", 
"PHONE_NUMBER", "EMAIL_ADDRESS"] 61 | 62 | # Threshold probabilities 63 | PII_THRESHOLD_PROB = 0.5 64 | 65 | MAX_LEN = int(1000000 / 10) 66 | 67 | # In characters 68 | CONTEXT_WINDOW = 20 69 | 70 | def scrub_doc(doc): 71 | analyzer, anonymizer = init_presidio() 72 | 73 | doc_text = doc["text"] 74 | doc_words = doc_text.split() 75 | doc_word_count = len(doc_words) 76 | 77 | 78 | def detect_pii(doc_text): 79 | analyzer_results = analyzer.analyze(text=doc_text, entities=ENTITIES, language='en') 80 | # Filter results to those with score >= PII_THRESHOLD_PROB 81 | # From manual inspection, PII_THRESHOLD_PROB = 0.5 seems reasonable 82 | results = [] 83 | for result in analyzer_results: 84 | result = result.to_dict() 85 | if result['score'] >= PII_THRESHOLD_PROB: 86 | doc_len_chars = len(doc_text) 87 | context = "" 88 | if result['start'] - CONTEXT_WINDOW < 0 and result['end'] + CONTEXT_WINDOW + 1 > doc_len_chars: 89 | context = doc_text[0:doc_len] 90 | elif result['start'] - CONTEXT_WINDOW < 0: 91 | context = doc_text[0:result['end'] + CONTEXT_WINDOW + 1] 92 | elif result['end'] + CONTEXT_WINDOW + 1 > doc_len_chars: 93 | context = doc_text[result['start'] - CONTEXT_WINDOW:doc_len_chars] 94 | else: 95 | context = doc_text[result['start'] - CONTEXT_WINDOW:result['end'] + CONTEXT_WINDOW + 1] 96 | results.append({'type': result['entity_type'], 'span': doc_text[result['start']: result['end']], 'context': context, 'start': result['start'], 'end': result['end'], 'score': result['score']}) 97 | return results 98 | 99 | # Detect PII 100 | try: 101 | doc_pii = detect_pii(doc_text) 102 | except ValueError as error: 103 | print(error) 104 | 105 | n = math.ceil(len(doc_text) / MAX_LEN) 106 | print(n) 107 | chunks = [doc_text[i:i+MAX_LEN] for i in range(0, len(doc_text), MAX_LEN)] 108 | doc_pii = [] 109 | for chunk in chunks: 110 | doc_pii_chunk = detect_pii(chunk) 111 | doc_pii += doc_pii_chunk 112 | 113 | doc_pii_count = len(doc_pii) 114 | 115 | # Aggregate doc log data 116 | doc_log_data = {} 117 | doc_log_data["url"] = doc["url"] 118 | doc_log_data["word_count"] = doc_word_count 119 | doc_log_data["pii_count"] = doc_pii_count 120 | doc_log_data["pii"] = doc_pii 121 | 122 | return doc_log_data 123 | 124 | 125 | def scrub(split, name, dataset): 126 | docs_log_data = [] 127 | 128 | # Global counts across the dataset 129 | pii_count = 0 130 | word_count = 0 131 | docs_with_pii = 0 132 | doc_count = 0 133 | 134 | results = None 135 | with Pool(processes=cpu_count() - 1) as p: 136 | results = list(tqdm(p.imap(scrub_doc, dataset), total=len(dataset))) 137 | 138 | for doc_log_data in results: 139 | word_count += doc_log_data["word_count"] 140 | pii_count += doc_log_data["pii_count"] 141 | if doc_log_data["pii_count"] > 0: 142 | docs_with_pii += 1 143 | doc_count += 1 144 | 145 | docs_log_data.append(doc_log_data) 146 | 147 | dataset_log_data = {} 148 | dataset_log_data["split"] = split 149 | dataset_log_data["name"] = name 150 | dataset_log_data["pii_count"] = pii_count 151 | dataset_log_data["word_count"] = word_count 152 | dataset_log_data["docs_with_pii"] = docs_with_pii 153 | dataset_log_data["doc_count"] = doc_count 154 | dataset_log_data["per_doc_log_data"] = docs_log_data 155 | 156 | return dataset_log_data 157 | 158 | 159 | def init_presidio(): 160 | analyzer = AnalyzerEngine() 161 | anonymizer = AnonymizerEngine() 162 | return analyzer, anonymizer 163 | 164 | 165 | def save_json_file(data, out_dir, filename): 166 | if not os.path.exists(out_dir): 167 | os.makedirs(out_dir) 168 | filepath = 
os.path.join(out_dir, filename) 169 | with open(filepath, 'w') as out_file: 170 | json.dump(data, out_file) 171 | 172 | 173 | def main(args): 174 | split = args.split 175 | name = args.name 176 | print(f"{split}.{name}") 177 | dataset = load_dataset(DATA_URL, args.name, use_auth_token=True, split=args.split, streaming=False) 178 | dataset_log_data = scrub(split, name, dataset) 179 | 180 | # Save dataset log data (statistics on PII) to json 181 | save_json_file(dataset_log_data, LOG_DIR, f'{split}.{name}.json') 182 | 183 | 184 | if __name__ == '__main__': 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument("--split") 187 | parser.add_argument("--name") 188 | args = parser.parse_args() 189 | main(args) 190 | -------------------------------------------------------------------------------- /scrub/scrub.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import json 4 | try: 5 | import lzma as xz 6 | except ImportError: 7 | import pylzma as xz 8 | import scrubadub 9 | from tqdm import tqdm 10 | 11 | # DATA DIRS 12 | IN_DIR = 'data' 13 | OUT_DIR = 'data_scrubbed' 14 | FILTH_DIR = 'filth' 15 | 16 | # remaining 17 | TO_SCRUB = ['train.atticus_contracts.jsonl.xz', 18 | 'validation.atticus_contracts.jsonl.xz', 19 | 'train.bva.jsonl.xz', 20 | 'validation.bva.jsonl.xz', 21 | 'train.congressional_hearings.xz', 22 | 'validation.congressional_hearings.xz', 23 | 'train.courtlistenerdocketentries.xz', 24 | 'validation.courtlistenerdocketentries.xz', 25 | 'train.courtlisteneropinions.xz', 26 | 'validation.courtlisteneropinions.xz', 27 | 'train.edgar.jsonl.xz', 28 | 'validation.edgar.jsonl.xz', 29 | ] 30 | 31 | 32 | def scrub(filepath, scrubber): 33 | """This function scrubs the data file of sensitive information. 
34 | Returns 35 | - A list of the documents in the data file, with text scrubbed 36 | - A dict of the file filth data 37 | """ 38 | print(f"Scrub {filepath}") 39 | 40 | print(f"Read {filepath}") 41 | lines = [] 42 | with xz.open(filepath, mode='rb') as f: 43 | while True: 44 | try: 45 | line = f.readline().decode('utf-8') 46 | if line == "": 47 | break 48 | lines.append(line) 49 | except: 50 | print("corrupted line") 51 | break 52 | 53 | 54 | print(f"Clean {filepath}") 55 | docs_scrubbed = [] 56 | docs_filth_data = [] 57 | 58 | filth_count = 0 59 | word_count = 0 60 | filth_doc_count = 0 61 | 62 | for line in tqdm(lines): 63 | if line is not None and line != "": 64 | doc = json.loads(line) 65 | doc_text = doc["text"] 66 | doc_word_count = len(doc_text.split()) 67 | word_count += doc_word_count 68 | 69 | # Detect filth 70 | filth_objs = list(scrubber.iter_filth(doc_text)) 71 | doc_filth_count = len(filth_objs) 72 | filth_count += doc_filth_count 73 | if doc_filth_count > 0: 74 | filth_doc_count += 1 75 | 76 | doc_filth = [] 77 | for filth_obj in filth_objs: 78 | filth = {} 79 | if isinstance(filth_obj, scrubadub.filth.base.MergedFilth): # filth_obj is MergedFilth 80 | types = [] 81 | texts = [] 82 | for filth_obj_i in filth_obj.filths: 83 | types.append(filth_obj_i.detector_name) 84 | texts.append(filth_obj_i.text) 85 | filth["type"] = types 86 | filth["text"] = texts 87 | filth["merged_text"] = filth_obj.text 88 | filth["merged"] = True 89 | else: # filth_obj is Filth 90 | filth["type"] = filth_obj.detector_name 91 | filth["text"] = filth_obj.text 92 | filth["merged"] = False 93 | doc_filth.append(filth) 94 | 95 | doc_filth_data = {} 96 | doc_filth_data["url"] = doc["url"] 97 | doc_filth_data["filth_count"] = doc_filth_count 98 | doc_filth_data["word_count"] = doc_word_count 99 | doc_filth_data["filth"] = doc_filth 100 | 101 | # Clean 102 | cleaned_text = scrubber.clean(doc_text) 103 | doc["text"] = cleaned_text 104 | 105 | docs_scrubbed.append(doc) 106 | docs_filth_data.append(doc_filth_data) 107 | 108 | file_filth_data = {} 109 | file_filth_data["filename"] = filepath.split('/')[-1] 110 | file_filth_data["filth_count"] = filth_count 111 | file_filth_data["word_count"] = word_count 112 | file_filth_data["filth_doc_count"] = filth_doc_count 113 | file_filth_data["doc_count"] = len(lines) 114 | file_filth_data["filth_data"] = docs_filth_data 115 | 116 | return docs_scrubbed, file_filth_data 117 | 118 | 119 | def init_scrubber(): 120 | """Initialize scrubber with detectors.""" 121 | detector_list = [scrubadub.detectors.en_US.SocialSecurityNumberDetector, 122 | scrubadub.detectors.PhoneDetector, 123 | scrubadub.detectors.EmailDetector, 124 | scrubadub.detectors.CreditCardDetector 125 | ] 126 | scrubber = scrubadub.Scrubber(detector_list=detector_list, locale='en_US') 127 | return scrubber 128 | 129 | 130 | def save_compressed_file(data, out_dir, filename): 131 | if not os.path.exists(out_dir): 132 | os.makedirs(out_dir) 133 | filepath = os.path.join(out_dir, filename) 134 | with xz.open(filepath, 'w') as out_file: 135 | for x in data: 136 | out_file.write((json.dumps(x) + "\n").encode("utf-8")) 137 | 138 | 139 | def save_json_file(data, out_dir, filename): 140 | if not os.path.exists(out_dir): 141 | os.makedirs(out_dir) 142 | filepath = os.path.join(out_dir, filename) 143 | with open(filepath, 'w') as out_file: 144 | json.dump(data, out_file) 145 | 146 | 147 | def main(): 148 | filepaths_all = glob.glob(os.path.join(IN_DIR, "*.xz")) 149 | filepaths_to_scrub = [filepath for filepath in 
filepaths_all if filepath.split('/')[-1] in TO_SCRUB] 150 | 151 | scrubber = init_scrubber() 152 | for filepath in filepaths_to_scrub: 153 | docs_scrubbed, filth_data = scrub(filepath, scrubber) 154 | 155 | filename = filepath.split('/')[-1] 156 | trunc_filename = ".".join(filename.split('.')[0:2]) 157 | 158 | # Save cleaned docs as compressed xz file 159 | save_compressed_file(docs_scrubbed, OUT_DIR, filename) 160 | 161 | # Save file filth data to json 162 | save_json_file(filth_data, FILTH_DIR, trunc_filename + ".json") 163 | 164 | 165 | if __name__ == '__main__': 166 | main() -------------------------------------------------------------------------------- /toxicity/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Processing Supreme Court Opinions with Toxicity APIs 3 | 4 | The scotus_only* files run toxicity scores over only scotus documents. Note: for this we use the case law access project supreme court decisions data because it contains accurate metadata for opinion dates. To run this, you will need to also download the bulk scotus opinions from CAP. The output of these scripts, a sentence by sentence scoring of all Supreme Court opinions can be found here: 5 | 6 | https://drive.google.com/drive/folders/1QvLdlBGHZYX6mv5HCh1SL_QZLCs5yN6-?usp=sharing 7 | 8 | This is roughly 11Gb of data. 9 | 10 | The data is formatted such that each row is a sentence indexed by document and a score from a different API. The mapping from sentence and document indices can be found in sent_mapping.pkl 11 | 12 | # Figure 2 13 | 14 | The Figure 2 generation code can be found in fig2 15 | 16 | -------------------------------------------------------------------------------- /toxicity/create_doc_sent_index.py: -------------------------------------------------------------------------------- 1 | # Detects toxic language with unitariyai/detoxify 2 | from collections import defaultdict 3 | from tqdm.contrib.concurrent import process_map # or thread_map 4 | from multiprocess import set_start_method 5 | set_start_method('spawn', force=True) 6 | 7 | 8 | from collections import defaultdict 9 | import csv 10 | import glob 11 | import os 12 | import json 13 | import numpy as np 14 | import math 15 | import random 16 | import argparse 17 | from detoxify import Detoxify 18 | from tqdm import tqdm 19 | from datasets import load_dataset 20 | from multiprocessing import Pool, cpu_count 21 | import lexnlp.nlp.en.segments.sentences as lexnlp 22 | from googleapiclient import discovery 23 | import time 24 | 25 | from sys import getsizeof 26 | 27 | from transformers import pipeline 28 | 29 | base_dir = "./" 30 | 31 | PIPE = None 32 | 33 | DATA_URL = 'pile-of-law/pile-of-law' 34 | 35 | 36 | SPLITS = ['train', 'validation'] 37 | 38 | # Threshold probabilities 39 | PROFANITY_THRESHOLD_PROB = 0.8 40 | 41 | MAX_LEN = int(1000000 / 10) 42 | 43 | 44 | def chunks(lst, n): 45 | """Yield successive n-sized chunks from lst.""" 46 | for i in range(0, len(lst), n): 47 | yield lst[i:i + n] 48 | 49 | 50 | def save_json_file(data, out_dir, filename): 51 | if not os.path.exists(out_dir): 52 | os.makedirs(out_dir) 53 | filepath = os.path.join(out_dir, filename) 54 | with open(filepath, 'w') as out_file: 55 | json.dump(data, out_file) 56 | 57 | class NumpyEncoder(json.JSONEncoder): 58 | def default(self, obj): 59 | if isinstance(obj, np.ndarray): 60 | return obj.tolist() 61 | return json.JSONEncoder.default(self, obj) 62 | 63 | def get_opinions(stuff): 64 | stuff = json.loads(stuff) 65 | return { 
66 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]), 67 | "decision_date" : stuff["decision_date"], 68 | "name" : stuff["name"], 69 | "docket_number" : stuff["docket_number"], 70 | "citations" : [x["cite"] for x in stuff["citations"]] 71 | } 72 | 73 | def main(args): 74 | split = args.split 75 | name = args.name 76 | print(f"{split}.{name}") 77 | results = [] 78 | with open("data.jsonl", "r") as f: 79 | opinions = f.readlines() 80 | opinions = [get_opinions(x) for x in opinions] 81 | opinions = [x for x in opinions if len(x["text"]) > 3000] 82 | dates = [x["decision_date"] for x in opinions] 83 | names = [x["name"] for x in opinions] 84 | 85 | from torch.utils.data import Dataset 86 | 87 | class MyDataset(Dataset): 88 | 89 | def __init__(self) -> None: 90 | super().__init__() 91 | self.data = defaultdict(dict) 92 | for doc_idx, opinion in enumerate(opinions): 93 | sentences = lexnlp.get_sentence_list(opinion["text"]) 94 | for sentence_idx, sentence in enumerate(sentences): 95 | self.data[doc_idx][sentence_idx] = { 96 | "sent" : sentence, 97 | "name" : opinion["name"], 98 | "docket_number" : opinion["docket_number"], 99 | "decision_date" : opinion["decision_date"], 100 | "citations" : opinion["citations"] 101 | } 102 | 103 | def __len__(self): 104 | return len(self.data) 105 | 106 | def __getitem__(self, i): 107 | return self.data[i]["text"] 108 | 109 | 110 | dataset = MyDataset() 111 | import pickle 112 | with open("sent_mapping.pkl", "wb") as f: 113 | pickle.dump(dataset.data, f) 114 | 115 | 116 | 117 | 118 | 119 | 120 | if __name__ == '__main__': 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("--split") 123 | parser.add_argument("--name") 124 | 125 | args = parser.parse_args() 126 | main(args) 127 | -------------------------------------------------------------------------------- /toxicity/fig2/code/fig2.R: -------------------------------------------------------------------------------- 1 | ######################################### 2 | #### Pile of Law, Figure 2 #### 3 | # Last updated: June 28 # 4 | ######################################### 5 | 6 | #### SETUP #### 7 | # Load software 8 | rm(list=ls()) 9 | library(psych) 10 | library(tidyverse) 11 | library(ggplot2) 12 | library(stringr) 13 | library(gridExtra) 14 | 15 | # Set root 16 | root <- "/Users/MarkKrass/Downloads/" 17 | 18 | ### Import Toxicity Scores 19 | # Get data: Profanity Checker 20 | pc <- read.csv(paste0(root, "pc_only_scores.csv"), 21 | colClasses = c("numeric","numeric", 22 | "factor","factor", 23 | "numeric","character","factor")) %>% 24 | mutate(year = as.numeric(str_extract(Date, "^[[:digit:]]{4}"))) %>% 25 | select(-Date) 26 | 27 | # Get data: Toxigen 28 | tx <- read.csv("/scratch/users/mkrass/pile/perspective_scores_toxigen_only.csv", 29 | stringsAsFactors = T) 30 | 31 | # Get data: Detoxify 32 | dt.max <- read.csv("/scratch/users/mkrass/pile/detoxify_scores.csv", 33 | stringsAsFactors = T )%>% 34 | group_by(DocIndex, SentenceIndex) %>% 35 | summarize(Score = max(Score)) 36 | 37 | # Get data: Perspective 38 | pp <- read.csv("/scratch/users/mkrass/pile/perspective_scores_pc_only.csv", 39 | stringsAsFactors = T) %>% 40 | filter(Category %in% toupper(tox_classes)) %>% 41 | group_by(DocIndex, SentenceIndex) %>% 42 | summarize(Score = max(Score)) 43 | 44 | ### Merge Toxicity Score 45 | pc <- tx %>% select(DocIndex,SentenceIndex,Score) %>% rename(toxigen=Score) %>% 46 | right_join(pc) 47 | pc <- dt.max %>% select(DocIndex,SentenceIndex,Score) %>% rename(detoxify=Score) 
%>% 48 | right_join(pc) 49 | pc <- pp %>% select(DocIndex,SentenceIndex,Score) %>% rename(perspective=Score) %>% 50 | right_join(pc) 51 | 52 | 53 | ### Load the Supreme Court Database (Spaeth, Epstein et al. 2021) 54 | load(paste0(root, "SCDB_2021_01_caseCentered_Citation.Rdata")) 55 | load(paste0(root, "SCDB_Legacy_07_caseCentered_Citation.Rdata")) 56 | 57 | # Get crosswalk of case names 58 | scdb.xw <- read.csv(paste0(root,"scdb_crosswalk.csv"), stringsAsFactors = T) %>% 59 | rename(Name=name) %>% 60 | mutate(year = as.numeric(substr(decision_date, start=1,stop=4))) 61 | 62 | # Rename 63 | scdb <- SCDB_2021_01_caseCentered_Citation %>% select(dateDecision, caseName, usCite, issueArea) ; rm(SCDB_2021_01_caseCentered_Citation) 64 | scdb.legacy <- SCDB_Legacy_07_caseCentered_Citation %>% select(dateDecision, caseName, usCite, issueArea); rm(SCDB_Legacy_07_caseCentered_Citation) 65 | 66 | 67 | pc <- pc %>% 68 | left_join(scdb.xw[,c("Name","year","usr","scdb")]) %>% 69 | distinct(DocIndex,SentenceIndex,.keep_all=T) 70 | 71 | 72 | pc <- pc %>% rename(caseId = scdb) %>% left_join(scdb) 73 | 74 | pcm.cr.long <- pc %>% 75 | select(year,DocIndex, issueArea,perspective,detoxify,Score,toxigen) %>% 76 | pivot_longer(perspective:toxigen,names_to="model",values_to="score2") 77 | 78 | 79 | 80 | # Save merged 81 | pc %>% saveRDS(file=paste0(root, "merged_maxes.RDS")) 82 | 83 | # (Optional: Load merged ) 84 | #pc <- readRDS(paste0(root,"merged_maxes.RDS")) 85 | 86 | 87 | #### Obtain Cohen's K at Multiple Thresholds #### 88 | 89 | roundUp <- function(x){round(x/10)*10} 90 | out <- c() 91 | yrs <- c() 92 | w <- c() 93 | ts <- c() 94 | # Group case years into 10-year bins 95 | pc$yr_bin <- sapply(pc$year, roundUp) 96 | 97 | # At multiple thresholds, obtain Cohen's K 98 | for(t in c(0.5,0.8,0.9,0.95)){ 99 | pc <- pc %>% mutate( 100 | detoxify.yn = ifelse(detoxify>t,1,0), 101 | profanitycheck.yn = ifelse(Score>t,1,0), 102 | toxigen.yn = ifelse(toxigen>t,1,0), 103 | perspective.yn = ifelse(perspective>t,1,0), 104 | iss_desc = case_when(issueArea == 2 ~ "Civil Rights", 105 | TRUE ~ "All Others")) 106 | # Exclude early years with few data points 107 | for(y in sort(unique(pc$yr_bin))[5:26]){ 108 | # Focus on comparing Perspective and Profanity-Checker, 109 | # since these are especially popular 110 | dat <- pc %>% 111 | filter(yr_bin == y) %>% 112 | select(perspective.yn,profanitycheck.yn) %>% 113 | drop_na() %>% 114 | as.matrix() 115 | k <- cohen.kappa(x=as.matrix(dat))$kappa 116 | out <- c(out,k) 117 | yrs <- c(yrs,y) 118 | w <- c(w, nrow(dat)) 119 | ts <- c(ts, t) 120 | }} 121 | 122 | # Collect data 123 | cpdat <- data.frame(k = out, year=yrs, w=w, ts=ts) 124 | 125 | #### Plot #### 126 | 127 | # Left panel: Cohen's K 128 | cohens <- cpdat %>% 129 | # Remove points with too few data points 130 | filter(ts %in% c(0.5),w>5000) %>% 131 | ggplot(aes(x=year,y=k)) + 132 | geom_point(aes(size=w)) + 133 | geom_line() + scale_size(guide="none") + 134 | ylab("Cohen's K") + xlim(1880,2010) 135 | 136 | 137 | # Right panel: issue scores by area over time 138 | plt.issue.area.share <- pcm.cr.long %>% 139 | mutate(iss_desc = case_when(issueArea == 2 ~ "Civil Rights", 140 | TRUE ~ "All Others"), 141 | model_desc = case_when(model == "Score" ~ "profanity_checker", 142 | TRUE ~ model), 143 | score.yn = ifelse(score2 > 0.5,1,0)) %>% 144 | group_by(year,model_desc,iss_desc) %>% 145 | summarize(score=mean(score.yn),n=n()) %>% 146 | ungroup() %>% 147 | ggplot(aes(x=year,y=score, color=model_desc)) + 148 | 
geom_smooth(span=0.3)+ylab("Share P(Toxic) > 0.5")+ 149 | facet_wrap(~iss_desc) + xlim(1875,2022) + theme(legend.position = "right", legend.text = element_text(size = 6), 150 | legend.title = element_text(size = 8)) + 151 | guides(color = guide_legend(override.aes = list(size = 0.5))) 152 | 153 | 154 | ggsave(plt.issue.area.share, filename = paste0(root,"share_toxic_issue.pdf"), 155 | device="pdf",width=7,height=3) 156 | 157 | 158 | ## Assemble plots 159 | assembled <- grid.arrange(cohens, plt.issue.area.share, nrow=1, widths=c(0.3,0.7)) 160 | ggsave(assembled, filename = paste0(root,"combined_cohens_toxic.pdf"), 161 | device="pdf",width=8,height=3) 162 | 163 | 164 | -------------------------------------------------------------------------------- /toxicity/fig2/data/SCDB.Rdata: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Breakend/PileOfLaw/e6f7b11ba52e0fbbd392ad00bf3cc1632ae710b6/toxicity/fig2/data/SCDB.Rdata -------------------------------------------------------------------------------- /toxicity/fig2/data/SCDBLegacy.Rdata: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Breakend/PileOfLaw/e6f7b11ba52e0fbbd392ad00bf3cc1632ae710b6/toxicity/fig2/data/SCDBLegacy.Rdata -------------------------------------------------------------------------------- /toxicity/fig2/data/SCDB_2018_02_caseCentered_Citation.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Breakend/PileOfLaw/e6f7b11ba52e0fbbd392ad00bf3cc1632ae710b6/toxicity/fig2/data/SCDB_2018_02_caseCentered_Citation.csv -------------------------------------------------------------------------------- /toxicity/fig2/data/SCDB_2019_01_caseCentered_Citation.Rdata: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Breakend/PileOfLaw/e6f7b11ba52e0fbbd392ad00bf3cc1632ae710b6/toxicity/fig2/data/SCDB_2019_01_caseCentered_Citation.Rdata -------------------------------------------------------------------------------- /toxicity/scotus_only_pc_only.py: -------------------------------------------------------------------------------- 1 | # Detects toxic language with unitariyai/detoxify 2 | from collections import defaultdict 3 | from tqdm.contrib.concurrent import process_map # or thread_map 4 | 5 | 6 | import csv 7 | import glob 8 | import os 9 | import json 10 | import numpy as np 11 | import math 12 | import random 13 | import argparse 14 | from profanity_check import predict_prob 15 | from detoxify import Detoxify 16 | from tqdm import tqdm 17 | from datasets import load_dataset 18 | from multiprocessing import Pool, cpu_count 19 | import lexnlp.nlp.en.segments.sentences as lexnlp 20 | from googleapiclient import discovery 21 | import time 22 | from apiclient.errors import HttpError 23 | 24 | from sklearn.feature_extraction.text import TfidfVectorizer 25 | from sklearn.linear_model import RidgeClassifier 26 | from sys import getsizeof 27 | from joblib import load 28 | 29 | from transformers import pipeline 30 | 31 | base_dir = "./" 32 | 33 | 34 | PIPE = None 35 | 36 | 37 | import os 38 | 39 | # DATA DIRS 40 | DATA_URL = 'pile-of-law/pile-of-law' 41 | 42 | 43 | SPLITS = ['train', 'validation'] 44 | 45 | # Threshold probabilities 46 | PROFANITY_THRESHOLD_PROB = 0.8 47 | 48 | MAX_LEN = int(1000000 / 10) 49 | 50 | 51 | 52 | def profanity_check(sentences): 53 | res = predict_prob(sentences) 
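    # predict_prob from profanity-check returns one probability per input
    # string, so res[i] below is the profanity score for sentences[i]. The
    # dict keyed by sentence index mirrors the {SentenceIndex: {Category: Score}}
    # layout used by the other scorers, which main() flattens into
    # pc_only_scores.csv.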
54 | return { i : {"PROFANITY" : res[i]} for i in range(len(sentences))} 55 | 56 | def exponential_backoff(req): 57 | for n in range(10): 58 | try: 59 | return client.comments().analyze(body=req).execute() 60 | except HttpError as error: 61 | if error.resp.reason.strip() not in ["Too Many Requests", "Resource has been exhausted (e.g. check quota)."]: 62 | print(error.resp.reason) 63 | import pdb; pdb.set_trace() 64 | raise 65 | if n < 9: 66 | time.sleep((random.random() * 5 * n) + random.random()) 67 | print(f"BACKING OFF {n} for {error.resp.reason}") 68 | else: 69 | raise 70 | else: 71 | break 72 | 73 | def _conver_label_to_score(l): 74 | if l["label"] == "LABEL_0": 75 | return 1 - l["score"] 76 | else: 77 | return l["score"] 78 | 79 | def chunks(lst, n): 80 | """Yield successive n-sized chunks from lst.""" 81 | for i in range(0, len(lst), n): 82 | yield lst[i:i + n] 83 | 84 | 85 | 86 | def scrub_doc(doc): 87 | doc_text = doc["text"] 88 | 89 | # Detect profanity 90 | # profanity-check 91 | sentences = lexnlp.get_sentence_list(doc_text) 92 | 93 | _return = {} 94 | if len(sentences) == 0: 95 | _return = {} 96 | _return["num_sentences"] = len(sentences) 97 | _return["examples"] = json.dumps({}) 98 | _return["confusion"] = json.dumps(np.zeros((4,4)), cls=NumpyEncoder) 99 | 100 | results_pc = profanity_check(sentences) 101 | 102 | _return["sentences"] = json.dumps(sentences) 103 | _return["profanity_check_scores"] = json.dumps(results_pc) 104 | _return["num_sentences"] = len(sentences) 105 | return _return 106 | 107 | 108 | def save_json_file(data, out_dir, filename): 109 | if not os.path.exists(out_dir): 110 | os.makedirs(out_dir) 111 | filepath = os.path.join(out_dir, filename) 112 | with open(filepath, 'w') as out_file: 113 | json.dump(data, out_file) 114 | 115 | class NumpyEncoder(json.JSONEncoder): 116 | def default(self, obj): 117 | if isinstance(obj, np.ndarray): 118 | return obj.tolist() 119 | return json.JSONEncoder.default(self, obj) 120 | 121 | def get_opinions(stuff): 122 | stuff = json.loads(stuff) 123 | return { 124 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]), 125 | "decision_date" : stuff["decision_date"], 126 | "name" : stuff["name"] 127 | } 128 | 129 | def main(args): 130 | split = args.split 131 | name = args.name 132 | print(f"{split}.{name}") 133 | results = [] 134 | with open("data.jsonl", "r") as f: 135 | opinions = f.readlines() 136 | opinions = [get_opinions(x) for x in opinions] 137 | opinions = [x for x in opinions if len(x["text"]) > 3000] 138 | dates = [x["decision_date"] for x in opinions] 139 | names = [x["name"] for x in opinions] 140 | 141 | print(len(opinions)) 142 | from datasets import Dataset 143 | df = Dataset.from_dict({k: [dic[k] for dic in opinions] for k in opinions[0].keys()}) 144 | import sys 145 | import traceback 146 | try: 147 | results2 = df.map(scrub_doc, num_proc=32) 148 | except Exception as e: 149 | import pdb; pdb.set_trace() 150 | del opinions 151 | 152 | total_sentences = 0.0 153 | 154 | 155 | with open(os.path.join(base_dir, "pc_only_scores.csv"), "w") as scores_f: 156 | spamwriter = csv.writer(scores_f) 157 | spamwriter.writerow(["DocIndex", "SentenceIndex", "API", "Category", "Score", "Date", "Name"]) 158 | for results in [results2]: 159 | for i, (result, name, date) in enumerate(zip(results,names, dates)): 160 | for t in ["profanity_check", "perspective"]: 161 | if f"{t}_scores" not in result: 162 | continue 163 | scores = json.loads(result[f"{t}_scores"]) 164 | # doc number 165 | for k, v in 
scores.items(): 166 | for k2, v2 in v.items(): 167 | datarow = (i, k, t, k2, v2, date, name) 168 | spamwriter.writerow(datarow) 169 | 170 | 171 | 172 | 173 | if __name__ == '__main__': 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument("--split") 176 | parser.add_argument("--name") 177 | 178 | args = parser.parse_args() 179 | main(args) 180 | -------------------------------------------------------------------------------- /toxicity/scotus_only_perspective_pc.py: -------------------------------------------------------------------------------- 1 | # Detects toxic language with unitariyai/detoxify 2 | from collections import defaultdict 3 | from tqdm.contrib.concurrent import process_map # or thread_map 4 | 5 | 6 | import csv 7 | import glob 8 | import os 9 | import json 10 | import numpy as np 11 | import math 12 | import random 13 | import argparse 14 | from profanity_check import predict_prob 15 | from detoxify import Detoxify 16 | from tqdm import tqdm 17 | from datasets import load_dataset 18 | from multiprocessing import Pool, cpu_count 19 | import lexnlp.nlp.en.segments.sentences as lexnlp 20 | from googleapiclient import discovery 21 | import time 22 | from apiclient.errors import HttpError 23 | 24 | from sklearn.feature_extraction.text import TfidfVectorizer 25 | from sklearn.linear_model import RidgeClassifier 26 | from sys import getsizeof 27 | from joblib import load 28 | 29 | from transformers import pipeline 30 | 31 | base_dir = "./" 32 | 33 | 34 | PIPE = None 35 | 36 | 37 | def toxigen_roberta(): 38 | # This will load the pipeline on demand on the current PROCESS/THREAD. 39 | # And load it only once. 40 | global PIPE 41 | if PIPE is None: 42 | PIPE = pipeline("text-classification", model="tomh/toxigen_roberta", device=0) 43 | return PIPE 44 | 45 | import os 46 | 47 | API_KEY = "" 48 | # DATA DIRS 49 | DATA_URL = 'pile-of-law/pile-of-law' 50 | 51 | 52 | client = discovery.build( 53 | "commentanalyzer", 54 | "v1alpha1", 55 | developerKey=API_KEY, 56 | discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1", 57 | static_discovery=False, 58 | ) 59 | 60 | SPLITS = ['train', 'validation'] 61 | 62 | # Threshold probabilities 63 | PROFANITY_THRESHOLD_PROB = 0.8 64 | 65 | MAX_LEN = int(1000000 / 10) 66 | 67 | def exponential_backoff(req): 68 | for n in range(10): 69 | try: 70 | return client.comments().analyze(body=req).execute() 71 | except HttpError as error: 72 | if error.resp.reason.strip() not in ["Too Many Requests", "Resource has been exhausted (e.g. check quota)."]: 73 | print(error.resp.reason) 74 | raise 75 | if n < 9: 76 | time.sleep((random.random() * 5 * n) + random.random()) 77 | print(f"BACKING OFF {n} for {error.resp.reason}") 78 | else: 79 | raise 80 | else: 81 | break 82 | 83 | def _conver_label_to_score(l): 84 | if l["label"] == "LABEL_0": 85 | return 1 - l["score"] 86 | else: 87 | return l["score"] 88 | 89 | def chunks(lst, n): 90 | """Yield successive n-sized chunks from lst.""" 91 | for i in range(0, len(lst), n): 92 | yield lst[i:i + n] 93 | 94 | def perspective_check_lexnlp(sentences): 95 | 96 | current_sentence_idx = 0 97 | results = defaultdict(dict) 98 | for sentence in sentences: 99 | # for idx, sentence in enumerate(sentences): 100 | # TODO: batching? 
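        # Each sentence goes out as its own Comment Analyzer request; the six
        # attributes requested below come back under
        # response["attributeScores"][<ATTR>]["summaryScore"]["value"], and
        # exponential_backoff() retries quota / rate-limit errors with jittered
        # sleeps before giving up.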
101 | 102 | analyze_request = { 103 | 'comment': { 'text': sentence.strip()}, 104 | 'spanAnnotations' : False, 105 | "languages" : ["en"], 106 | 'requestedAttributes':{"TOXICITY": {}, "SEVERE_TOXICITY": {}, "IDENTITY_ATTACK" : {}, "INSULT" : {}, "PROFANITY" : {}, "THREAT" : {}} 107 | } 108 | try: 109 | response = exponential_backoff(analyze_request) 110 | except Exception as e: 111 | results[current_sentence_idx]["ERROR"] = 1.0 112 | print(f"ERROR ON SENTENCE {sentence} {e}") 113 | 114 | for k, v in response["attributeScores"].items(): 115 | results[current_sentence_idx][k] = v["summaryScore"]['value'] 116 | current_sentence_idx += 1 117 | 118 | return results 119 | 120 | 121 | def scrub_doc(doc): 122 | doc_text = doc["text"] 123 | 124 | sentences = lexnlp.get_sentence_list(doc_text) 125 | results_p = perspective_check_lexnlp(sentences) 126 | _return = {} 127 | if len(sentences) == 0: 128 | _return = {} 129 | _return["num_sentences"] = len(sentences) 130 | _return["examples"] = json.dumps({}) 131 | _return["confusion"] = json.dumps(np.zeros((4,4)), cls=NumpyEncoder) 132 | 133 | _return["sentences"] = json.dumps(sentences) 134 | _return["perspective_scores"] = json.dumps(results_p) 135 | _return["num_sentences"] = len(sentences) 136 | return _return 137 | 138 | 139 | def save_json_file(data, out_dir, filename): 140 | if not os.path.exists(out_dir): 141 | os.makedirs(out_dir) 142 | filepath = os.path.join(out_dir, filename) 143 | with open(filepath, 'w') as out_file: 144 | json.dump(data, out_file) 145 | 146 | class NumpyEncoder(json.JSONEncoder): 147 | def default(self, obj): 148 | if isinstance(obj, np.ndarray): 149 | return obj.tolist() 150 | return json.JSONEncoder.default(self, obj) 151 | 152 | def get_opinions(stuff): 153 | stuff = json.loads(stuff) 154 | return { 155 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]), 156 | "decision_date" : stuff["decision_date"], 157 | "name" : stuff["name"] 158 | } 159 | 160 | def main(args): 161 | split = args.split 162 | name = args.name 163 | print(f"{split}.{name}") 164 | results = [] 165 | with open("data.jsonl", "r") as f: 166 | opinions = f.readlines() 167 | opinions = [get_opinions(x) for x in opinions] 168 | opinions = [x for x in opinions if len(x["text"]) > 3000] 169 | dates = [x["decision_date"] for x in opinions] 170 | names = [x["name"] for x in opinions] 171 | 172 | print(len(opinions)) 173 | from datasets import Dataset 174 | df = Dataset.from_dict({k: [dic[k] for dic in opinions] for k in opinions[0].keys()}) 175 | results2 = df.map(scrub_doc, num_proc=32) 176 | del opinions 177 | 178 | with open(os.path.join(base_dir, "perspective_scores_pc_only.csv"), "w") as scores_f: 179 | spamwriter = csv.writer(scores_f) 180 | spamwriter.writerow(["DocIndex", "SentenceIndex", "API", "Category", "Score", "Date", "Name"]) 181 | for results in [results2]: 182 | for i, (result, name, date) in enumerate(zip(results,names, dates)): 183 | for t in ["perspective"]: 184 | if f"{t}_scores" not in result: 185 | continue 186 | scores = json.loads(result[f"{t}_scores"]) 187 | # doc number 188 | for k, v in scores.items(): 189 | for k2, v2 in v.items(): 190 | datarow = (i, k, t, k2, v2, date, name) 191 | spamwriter.writerow(datarow) 192 | 193 | 194 | 195 | if __name__ == '__main__': 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument("--split") 198 | parser.add_argument("--name") 199 | 200 | args = parser.parse_args() 201 | main(args) 202 | 
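# Illustrative helper, not used by the pipeline above: a minimal sketch for
# spot-checking a single sentence with the same Perspective request shape as
# perspective_check_lexnlp(). It reuses this file's client via
# exponential_backoff() and assumes API_KEY at the top of the file has been
# filled in with a valid key. The function name is ours, not the authors'.
def score_single_sentence(sentence):
    """Return {attribute: summary score in [0, 1]} for one sentence."""
    analyze_request = {
        'comment': {'text': sentence.strip()},
        'spanAnnotations': False,
        'languages': ['en'],
        'requestedAttributes': {
            'TOXICITY': {}, 'SEVERE_TOXICITY': {}, 'IDENTITY_ATTACK': {},
            'INSULT': {}, 'PROFANITY': {}, 'THREAT': {},
        },
    }
    response = exponential_backoff(analyze_request)
    return {
        attr: scores['summaryScore']['value']
        for attr, scores in response['attributeScores'].items()
    }

# Example (requires a valid API key):
# score_single_sentence("The petition is frivolous.")["TOXICITY"]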
-------------------------------------------------------------------------------- /toxicity/scotus_only_toxigen.py: -------------------------------------------------------------------------------- 1 | # Detects toxic language with unitariyai/toxigen 2 | from collections import defaultdict 3 | from tqdm.contrib.concurrent import process_map # or thread_map 4 | from multiprocess import set_start_method 5 | set_start_method('spawn', force=True) 6 | 7 | 8 | import csv 9 | import glob 10 | import os 11 | import json 12 | import numpy as np 13 | import math 14 | import random 15 | import argparse 16 | from profanity_check import predict_prob 17 | from tqdm import tqdm 18 | from datasets import load_dataset 19 | from multiprocessing import Pool, cpu_count 20 | import lexnlp.nlp.en.segments.sentences as lexnlp 21 | from googleapiclient import discovery 22 | import time 23 | from apiclient.errors import HttpError 24 | 25 | from sklearn.feature_extraction.text import TfidfVectorizer 26 | from sklearn.linear_model import RidgeClassifier 27 | from sys import getsizeof 28 | 29 | from transformers import pipeline 30 | 31 | base_dir = "./" 32 | 33 | 34 | PIPE = None 35 | 36 | DATA_URL = 'pile-of-law/pile-of-law' 37 | 38 | 39 | SPLITS = ['train', 'validation'] 40 | 41 | # Threshold probabilities 42 | PROFANITY_THRESHOLD_PROB = 0.8 43 | 44 | MAX_LEN = int(1000000 / 10) 45 | 46 | 47 | def chunks(lst, n): 48 | """Yield successive n-sized chunks from lst.""" 49 | for i in range(0, len(lst), n): 50 | yield lst[i:i + n] 51 | 52 | 53 | def save_json_file(data, out_dir, filename): 54 | if not os.path.exists(out_dir): 55 | os.makedirs(out_dir) 56 | filepath = os.path.join(out_dir, filename) 57 | with open(filepath, 'w') as out_file: 58 | json.dump(data, out_file) 59 | 60 | class NumpyEncoder(json.JSONEncoder): 61 | def default(self, obj): 62 | if isinstance(obj, np.ndarray): 63 | return obj.tolist() 64 | return json.JSONEncoder.default(self, obj) 65 | 66 | def get_opinions(stuff): 67 | stuff = json.loads(stuff) 68 | return { 69 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]), 70 | "decision_date" : stuff["decision_date"], 71 | "name" : stuff["name"] 72 | } 73 | 74 | def toxigen_roberta(): 75 | # This will load the pipeline on demand on the current PROCESS/THREAD. 76 | # And load it only once. 
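    # Keeping one pipeline per process avoids re-loading the ToxiGen RoBERTa
    # weights onto the GPU for every call; with return_all_scores=True each
    # prediction is a list of {label, score} dicts, which main() collapses into
    # a per-sentence {label: score} mapping before writing toxigen_scores.csv.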
77 | global PIPE 78 | if PIPE is None: 79 | PIPE = pipeline("text-classification", model="tomh/toxigen_roberta", device=0, return_all_scores=True) 80 | return PIPE 81 | 82 | import os 83 | 84 | def main(args): 85 | set_start_method('spawn', force=True) 86 | split = args.split 87 | name = args.name 88 | print(f"{split}.{name}") 89 | results = [] 90 | with open("data.jsonl", "r") as f: 91 | opinions = f.readlines() 92 | opinions = [get_opinions(x) for x in opinions] 93 | opinions = [x for x in opinions if len(x["text"]) > 3000] 94 | dates = [x["decision_date"] for x in opinions] 95 | names = [x["name"] for x in opinions] 96 | 97 | print(len(opinions)) 98 | from torch.utils.data import Dataset 99 | 100 | class MyDataset(Dataset): 101 | 102 | def __init__(self) -> None: 103 | super().__init__() 104 | self.data = [] 105 | for doc_idx, opinion in enumerate(opinions): 106 | sentences = lexnlp.get_sentence_list(opinion["text"]) 107 | for sentence_idx, sentence in enumerate(sentences): 108 | self.data.append( 109 | { 110 | "sentence_idx" : sentence_idx, 111 | "doc_idx" : doc_idx, 112 | "text" : sentence 113 | }) 114 | def __len__(self): 115 | return len(self.data) 116 | 117 | def __getitem__(self, i): 118 | return self.data[i]["text"] 119 | 120 | 121 | dataset = MyDataset() 122 | 123 | 124 | results1 = [] 125 | cur_doc = defaultdict(dict) 126 | cur_doc_idx = 0 127 | try: 128 | for i, out in enumerate(tqdm(toxigen_roberta()(dataset, batch_size=32, truncation=True, max_length=512), total=len(dataset))): 129 | if cur_doc_idx != dataset.data[i]["doc_idx"]: 130 | cur_doc_idx = dataset.data[i]["doc_idx"] 131 | results1.append({"toxigen_scores" : json.dumps(cur_doc)}) 132 | cur_doc = defaultdict(dict) 133 | cur_doc[dataset.data[i]["sentence_idx"]] = { x["label"] : x["score"] for x in out } 134 | 135 | if len(cur_doc) > 0: 136 | results1.append({"toxigen_scores" : json.dumps(cur_doc)}) 137 | cur_doc = defaultdict(dict) 138 | except: 139 | import pdb; pdb.set_trace() 140 | 141 | 142 | with open(os.path.join(base_dir, "toxigen_scores.csv"), "w") as scores_f: 143 | spamwriter = csv.writer(scores_f) 144 | spamwriter.writerow(["DocIndex", "SentenceIndex", "API", "Category", "Score", "Date", "Name"]) 145 | for results in [results1]: 146 | for i, (result, name, date) in enumerate(zip(results,names, dates)): 147 | for t in ["toxigen"]: 148 | if f"{t}_scores" not in result: 149 | continue 150 | scores = json.loads(result[f"{t}_scores"]) 151 | # doc number 152 | for k, v in scores.items(): 153 | for k2, v2 in v.items(): 154 | datarow = (i, k, t, k2, v2, date, name) 155 | spamwriter.writerow(datarow) 156 | 157 | 158 | if __name__ == '__main__': 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument("--split") 161 | parser.add_argument("--name") 162 | 163 | args = parser.parse_args() 164 | main(args) 165 | -------------------------------------------------------------------------------- /toxicity/scotus_only_unitary.py: -------------------------------------------------------------------------------- 1 | # Detects toxic language with unitariyai/detoxify 2 | from collections import defaultdict 3 | from tqdm.contrib.concurrent import process_map # or thread_map 4 | from multiprocess import set_start_method 5 | set_start_method('spawn', force=True) 6 | 7 | 8 | import csv 9 | import glob 10 | import os 11 | import json 12 | import numpy as np 13 | import math 14 | import random 15 | import argparse 16 | from profanity_check import predict_prob 17 | from tqdm import tqdm 18 | from datasets import load_dataset 19 
| from multiprocessing import Pool, cpu_count 20 | import lexnlp.nlp.en.segments.sentences as lexnlp 21 | import time 22 | 23 | from sklearn.feature_extraction.text import TfidfVectorizer 24 | from sklearn.linear_model import RidgeClassifier 25 | from sys import getsizeof 26 | 27 | from transformers import pipeline 28 | 29 | base_dir = "./" 30 | 31 | 32 | PIPE = None 33 | 34 | DATA_URL = 'pile-of-law/pile-of-law' 35 | 36 | 37 | SPLITS = ['train', 'validation'] 38 | 39 | # Threshold probabilities 40 | PROFANITY_THRESHOLD_PROB = 0.8 41 | 42 | MAX_LEN = int(1000000 / 10) 43 | 44 | 45 | def chunks(lst, n): 46 | """Yield successive n-sized chunks from lst.""" 47 | for i in range(0, len(lst), n): 48 | yield lst[i:i + n] 49 | 50 | 51 | def save_json_file(data, out_dir, filename): 52 | if not os.path.exists(out_dir): 53 | os.makedirs(out_dir) 54 | filepath = os.path.join(out_dir, filename) 55 | with open(filepath, 'w') as out_file: 56 | json.dump(data, out_file) 57 | 58 | class NumpyEncoder(json.JSONEncoder): 59 | def default(self, obj): 60 | if isinstance(obj, np.ndarray): 61 | return obj.tolist() 62 | return json.JSONEncoder.default(self, obj) 63 | 64 | def get_opinions(stuff): 65 | stuff = json.loads(stuff) 66 | return { 67 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]), 68 | "decision_date" : stuff["decision_date"], 69 | "name" : stuff["name"] 70 | } 71 | 72 | def unitary_roberta(): 73 | # This will load the pipeline on demand on the current PROCESS/THREAD. 74 | # And load it only once. 75 | global PIPE 76 | if PIPE is None: 77 | PIPE = pipeline("text-classification", model="unitary/unbiased-toxic-roberta", device=0, return_all_scores=True) 78 | return PIPE 79 | 80 | import os 81 | 82 | def main(args): 83 | set_start_method('spawn', force=True) 84 | split = args.split 85 | name = args.name 86 | print(f"{split}.{name}") 87 | results = [] 88 | with open("data.jsonl", "r") as f: 89 | opinions = f.readlines() 90 | opinions = [get_opinions(x) for x in opinions] 91 | opinions = [x for x in opinions if len(x["text"]) > 3000] 92 | dates = [x["decision_date"] for x in opinions] 93 | names = [x["name"] for x in opinions] 94 | 95 | print(len(opinions)) 96 | from torch.utils.data import Dataset 97 | 98 | class MyDataset(Dataset): 99 | 100 | def __init__(self) -> None: 101 | super().__init__() 102 | self.data = [] 103 | for doc_idx, opinion in enumerate(opinions): 104 | sentences = lexnlp.get_sentence_list(opinion["text"]) 105 | for sentence_idx, sentence in enumerate(sentences): 106 | self.data.append( 107 | { 108 | "sentence_idx" : sentence_idx, 109 | "doc_idx" : doc_idx, 110 | "text" : sentence 111 | }) 112 | def __len__(self): 113 | return len(self.data) 114 | 115 | def __getitem__(self, i): 116 | return self.data[i]["text"] 117 | 118 | 119 | dataset = MyDataset() 120 | 121 | 122 | results1 = [] 123 | cur_doc = defaultdict(dict) 124 | cur_doc_idx = 0 125 | try: 126 | for i, out in enumerate(tqdm(unitary_roberta()(dataset, batch_size=32, truncation=True, max_length=512), total=len(dataset))): 127 | if cur_doc_idx != dataset.data[i]["doc_idx"]: 128 | cur_doc_idx = dataset.data[i]["doc_idx"] 129 | results1.append({"detoxify_scores" : json.dumps(cur_doc)}) 130 | cur_doc = defaultdict(dict) 131 | cur_doc[dataset.data[i]["sentence_idx"]] = { x["label"] : x["score"] for x in out } 132 | 133 | if len(cur_doc) > 0: 134 | results1.append({"detoxify_scores" : json.dumps(cur_doc)}) 135 | cur_doc = defaultdict(dict) 136 | except: 137 | import pdb; pdb.set_trace() 138 | 139 | 140 | 
with open(os.path.join(base_dir, "detoxify_scores.csv"), "w") as scores_f: 141 | spamwriter = csv.writer(scores_f) 142 | spamwriter.writerow(["DocIndex", "SentenceIndex", "API", "Category", "Score", "Date", "Name"]) 143 | for results in [results1]: 144 | for i, (result, name, date) in enumerate(zip(results,names, dates)): 145 | for t in ["detoxify"]: 146 | if f"{t}_scores" not in result: 147 | continue 148 | scores = json.loads(result[f"{t}_scores"]) 149 | # doc number 150 | for k, v in scores.items(): 151 | for k2, v2 in v.items(): 152 | datarow = (i, k, t, k2, v2, date, name) 153 | spamwriter.writerow(datarow) 154 | 155 | 156 | if __name__ == '__main__': 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument("--split") 159 | parser.add_argument("--name") 160 | 161 | args = parser.parse_args() 162 | main(args) 163 | -------------------------------------------------------------------------------- /toxicity/toxigen_context_exp.py: -------------------------------------------------------------------------------- 1 | # Detects toxic language with unitariyai/detoxify 2 | from collections import defaultdict 3 | from tqdm.contrib.concurrent import process_map # or thread_map 4 | from multiprocess import set_start_method 5 | set_start_method('spawn', force=True) 6 | import pandas as pd 7 | 8 | import csv 9 | import glob 10 | import os 11 | import json 12 | import numpy as np 13 | import math 14 | import random 15 | import pickle 16 | import argparse 17 | from profanity_check import predict_prob 18 | from detoxify import Detoxify 19 | from tqdm import tqdm 20 | from datasets import load_dataset 21 | from multiprocessing import Pool, cpu_count 22 | import lexnlp.nlp.en.segments.sentences as lexnlp 23 | from googleapiclient import discovery 24 | import time 25 | from apiclient.errors import HttpError 26 | 27 | from sklearn.feature_extraction.text import TfidfVectorizer 28 | from sklearn.linear_model import RidgeClassifier 29 | from sys import getsizeof 30 | from joblib import load 31 | 32 | from transformers import pipeline 33 | 34 | base_dir = "./" 35 | 36 | PIPE = None 37 | 38 | 39 | def toxigen_roberta(): 40 | # This will load the pipeline on demand on the current PROCESS/THREAD. 41 | # And load it only once. 
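    # Unlike scotus_only_toxigen.py, this pipeline is built without
    # return_all_scores, so each prediction is a single {label, score} dict;
    # _conver_label_to_score() below maps a LABEL_0 prediction to 1 - score,
    # so the returned value is always the probability of the toxic label.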
42 | global PIPE 43 | if PIPE is None: 44 | PIPE = pipeline("text-classification", model="tomh/toxigen_roberta", device=0) 45 | return PIPE 46 | 47 | import os 48 | 49 | # DATA DIRS 50 | DATA_URL = 'pile-of-law/pile-of-law' 51 | 52 | 53 | SPLITS = ['train', 'validation'] 54 | 55 | # Threshold probabilities 56 | PROFANITY_THRESHOLD_PROB = 0.8 57 | 58 | MAX_LEN = int(1000000 / 10) 59 | 60 | 61 | 62 | def _conver_label_to_score(l): 63 | if l["label"] == "LABEL_0": 64 | return 1 - l["score"] 65 | else: 66 | return l["score"] 67 | 68 | def chunks(lst, n): 69 | """Yield successive n-sized chunks from lst.""" 70 | for i in range(0, len(lst), n): 71 | yield lst[i:i + n] 72 | 73 | def toxigen(sentences): 74 | results = {} 75 | predictions = toxigen_roberta()(sentences) 76 | for i in range(len(sentences)): 77 | results[i] = { "TOXICITY" : _conver_label_to_score(predictions[i]) } 78 | return results 79 | 80 | 81 | def save_json_file(data, out_dir, filename): 82 | if not os.path.exists(out_dir): 83 | os.makedirs(out_dir) 84 | filepath = os.path.join(out_dir, filename) 85 | with open(filepath, 'w') as out_file: 86 | json.dump(data, out_file) 87 | 88 | class NumpyEncoder(json.JSONEncoder): 89 | def default(self, obj): 90 | if isinstance(obj, np.ndarray): 91 | return obj.tolist() 92 | return json.JSONEncoder.default(self, obj) 93 | 94 | def get_opinions(stuff): 95 | stuff = json.loads(stuff) 96 | return { 97 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]), 98 | "decision_date" : stuff["decision_date"], 99 | "name" : stuff["name"] 100 | } 101 | 102 | def main(args): 103 | toxigen_scores = pd.read_csv("./perspective_scores_toxigen_only.csv") 104 | toxigen_scores = toxigen_scores.nlargest(5000,['Score']) 105 | toxigen_scores = toxigen_scores[toxigen_scores["Score"] > .5] 106 | print(len(toxigen_scores)) 107 | with open("./sent_mapping.pkl", "rb") as f: 108 | sentence_dict = pickle.load(f) 109 | from torch.utils.data import Dataset 110 | 111 | class MyDataset(Dataset): 112 | 113 | def __init__(self) -> None: 114 | super().__init__() 115 | self.data = [] 116 | for i, row in sorted(toxigen_scores.iterrows(), key=lambda x: x[1]["Score"], reverse=True): 117 | sents = [] 118 | low = max(0, row["SentenceIndex"]-2) 119 | doc_length = max(sentence_dict[row["DocIndex"]].keys())+1 120 | high = min(doc_length, row["SentenceIndex"]+2) 121 | for i in range(low, high): 122 | sents.append(sentence_dict[row["DocIndex"]][i]["sent"]) 123 | sentence = " ".join(sents) 124 | self.data.append( 125 | { 126 | "sentence_idx" : row["SentenceIndex"], 127 | "doc_idx" : row["DocIndex"], 128 | "prev_score" : row["Score"], 129 | "text" : sentence, 130 | "prev_sentence" : sentence_dict[row["DocIndex"]][row["SentenceIndex"]]["sent"] 131 | }) 132 | def __len__(self): 133 | return len(self.data) 134 | 135 | def __getitem__(self, i): 136 | return self.data[i]["text"] 137 | 138 | 139 | dataset = MyDataset() 140 | 141 | 142 | results1 = [] 143 | sentences = [] 144 | prev_sentences = [] 145 | prev_scores = [] 146 | cur_doc = defaultdict(dict) 147 | cur_doc_idx = 0 148 | for i, out in enumerate(tqdm(toxigen_roberta()(dataset, batch_size=32, truncation=True, max_length=512), total=len(dataset))): 149 | if cur_doc_idx != dataset.data[i]["doc_idx"]: 150 | cur_doc_idx = dataset.data[i]["doc_idx"] 151 | results1.append({"toxigen_scores" : json.dumps(cur_doc)}) 152 | cur_doc = defaultdict(dict) 153 | cur_doc[dataset.data[i]["sentence_idx"]] = { "TOXICITY_DIFF" : _conver_label_to_score(out) - dataset.data[i]["prev_score"]} 
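        # TOXICITY_DIFF is the ToxiGen score of the sentence re-evaluated with
        # its neighboring sentences (range(SentenceIndex - 2, SentenceIndex + 2))
        # minus the original single-sentence score loaded from
        # perspective_scores_toxigen_only.csv, so negative values mean the
        # added context made the sentence look less toxic.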
154 | sentences.append(dataset.data[i]["text"]) 155 | prev_sentences.append(dataset.data[i]["prev_sentence"]) 156 | prev_scores.append(dataset.data[i]["prev_score"]) 157 | if len(cur_doc) > 0: 158 | results1.append({"toxigen_scores" : json.dumps(cur_doc)}) 159 | cur_doc = defaultdict(dict) 160 | 161 | 162 | with open(os.path.join(base_dir, "perspective_scores_toxigen_only_context_exp.csv"), "w") as scores_f: 163 | spamwriter = csv.writer(scores_f) 164 | spamwriter.writerow(["DocIndex", "SentenceIndex", "API", "Category", "Score", "PrevScore", "PrevSentence", "CurSentence"]) 165 | for results in [results1]: 166 | for i, (result, prev_score, prev_sentence, cur_sentence) in enumerate(zip(results, prev_scores, prev_sentences, sentences)): 167 | for t in ["profanity_check", "perspective", "toxigen"]: 168 | if f"{t}_scores" not in result: 169 | continue 170 | scores = json.loads(result[f"{t}_scores"]) 171 | # doc number 172 | for k, v in scores.items(): 173 | for k2, v2 in v.items(): 174 | datarow = (i, k, t, k2, v2, prev_score, prev_sentence, cur_sentence) 175 | spamwriter.writerow(datarow) 176 | 177 | 178 | if __name__ == '__main__': 179 | parser = argparse.ArgumentParser() 180 | parser.add_argument("--split") 181 | parser.add_argument("--name") 182 | 183 | args = parser.parse_args() 184 | main(args) 185 | --------------------------------------------------------------------------------