├── .gitignore
├── README.md
├── dataset_creation
│   ├── .gitignore
│   ├── NLRBScraper.ipynb
│   ├── bar_outlines
│   │   └── scrape_bar_exam_outlines.py
│   ├── bva
│   │   └── process_bva.py
│   ├── canadian_decisions
│   │   ├── grab_canadian_decisions.ipynb
│   │   └── post_process.py
│   ├── collect_atticus_contracts.py
│   ├── congressional_hearings
│   │   └── scrape_congressional_hearings.py
│   ├── constitutions
│   │   └── scrape_constitutions.py
│   ├── courtlistener
│   │   ├── scrape_and_process_cl_data.py
│   │   └── scrape_and_process_cl_docket_data.py
│   ├── creative_commons_casebooks
│   │   └── scrape_cc_casebooks.py
│   ├── creditcardcfpb
│   │   └── scrape_cfbp_cc_agreements.py
│   ├── dol_ecab
│   │   └── scrape_dol_ecab.py
│   ├── echr
│   │   └── scrape_echr.py
│   ├── edgar_contracts
│   │   └── process_edgar_contracts_dataset.py
│   ├── eoir
│   │   └── scrape_eoir.py
│   ├── euro_parl.py
│   ├── eurolex
│   │   └── scrape_eurolex.py
│   ├── federal_register
│   │   └── process_federal_register.py
│   ├── federal_rules
│   │   ├── frcp.py
│   │   └── fre.py
│   ├── founder_docs
│   │   └── scrape_founding_docs.py
│   ├── ftcadvisories
│   │   └── scrape_ftc_advisories.py
│   ├── oig_reports
│   │   └── scrape_oig_reports.py
│   ├── olc
│   │   └── scrape_olc_memos.py
│   ├── process_cfr.py
│   ├── process_scotus_oral_arguments.py
│   ├── r_legaladvice
│   │   └── scrape_r_legaladvice.py
│   ├── scotus_dockets
│   │   └── scrape_scotus_dockets.py
│   ├── scrape_nlrb_decisions.py
│   ├── state_codes
│   │   ├── process_existing_state_code_files.py
│   │   └── state_codes_from_scratch.py
│   ├── tax_corpus_jhu
│   │   └── process_tax_corpus_jhu.py
│   ├── tos
│   │   └── process_tos.py
│   ├── un_debates
│   │   └── scrape_un_debates.py
│   ├── us_bills
│   │   ├── process_us_bills.py
│   │   ├── run.sh
│   │   └── split_us_bills.py
│   └── uscode
│       └── scrape_us_code.py
├── pretraining
│   ├── README.md
│   ├── chunkify_and_hd5.py
│   ├── new_vocab.py
│   └── pol-finetuning-main
│       ├── README.md
│       ├── environment.yml
│       ├── scripts
│       │   └── run_casehold.sh
│       └── tasks
│           ├── adam_bias_correct.py
│           ├── casehold.py
│           ├── casehold_bc.py
│           └── casehold_helpers.py
├── privacy
│   ├── README.md
│   ├── eoir
│   │   ├── EOIR.ipynb
│   │   ├── EOIR_validation_exp.ipynb
│   │   ├── README.md
│   │   ├── causal_exp.py
│   │   └── create_pseudonyms_dataset.py
│   └── janedoe
│       ├── README.md
│       ├── jane_doe.py
│       ├── jane_doe_negative.py
│       └── jane_doe_plot.py
├── scrub
│   ├── README.md
│   ├── pii.py
│   └── scrub.py
└── toxicity
    ├── README.md
    ├── create_doc_sent_index.py
    ├── fig2
    │   ├── code
    │   │   └── fig2.R
    │   └── data
    │       ├── SCDB.Rdata
    │       ├── SCDBLegacy.Rdata
    │       ├── SCDB_2018_02_caseCentered_Citation.csv
    │       ├── SCDB_2019_01_caseCentered_Citation.Rdata
    │       └── SCDB_crosswalk.csv
    ├── scotus_only_pc_only.py
    ├── scotus_only_perspective_pc.py
    ├── scotus_only_toxigen.py
    ├── scotus_only_unitary.py
    └── toxigen_context_exp.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | dataset_creation/*/cache/
131 | dataset_creation/*/cache_bak
132 | .DS_Store
133 | .tmp
134 | scratch
135 | *sqlite
136 | *.jsonl
137 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PileOfLaw
2 |
3 | This is the codebase for the experiments and the data-scraping tools used to gather the Pile of Law.
4 |
5 | Note the Pile of Law dataset itself can be found here: https://huggingface.co/datasets/pile-of-law/pile-of-law
6 |
7 | Pretrained models (though they may be undertrained) can be found here:
8 |
9 | https://huggingface.co/pile-of-law/legalbert-large-1.7M-1
10 |
11 | https://huggingface.co/pile-of-law/legalbert-large-1.7M-2
12 |
13 | Note that the main model we reference in the paper is *-2.
14 |
15 | We can make intermediate checkpoints (every 50k steps) available on request; we do not upload them all because they would require a significant amount of storage.
16 |
17 | The EOIR privacy pretrained model is available at https://huggingface.co/pile-of-law/distilbert-base-uncased-finetuned-eoir_privacy
18 |
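The pretrained models can be loaded directly from the Hugging Face Hub. A minimal sketch is below; the fill-mask sentence is illustrative and assumes the usual BERT-style `[MASK]` token.

```python
# Minimal sketch of loading the main released model from the Hugging Face Hub.
# The example sentence is illustrative; see the model card for details.
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

model_name = "pile-of-law/legalbert-large-1.7M-2"  # the main model referenced in the paper
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
print(fill_mask("The court granted the motion to [MASK] the complaint."))
```
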
19 | All model cards and datacards are included in the associated repositories. Code for the paper is available as follows.
20 |
21 | ## Dataset Creation
22 |
23 | All of the tools used to scrape or process each subset of the data live in the dataset_creation subfolder. Each scraper writes out the same lightweight JSONL record format and train/validation split, sketched below.
24 |
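The scrapers share an output convention: each document is a JSON object with `url`, `text`, `created_timestamp`, and `downloaded_timestamp` fields, shuffled with a fixed seed and split roughly 75/25 into train/validation JSONL files (usually xz-compressed). A minimal sketch of that convention is below; the `write_split` helper name and the example URL are illustrative, not part of the repository.

```python
# Sketch of the shared output convention used by the scrapers in dataset_creation/.
# Field names, the fixed seed, and the 75/25 split mirror the scripts in this repo;
# the helper name `write_split` and the example URL are illustrative.
import datetime
import json
import lzma as xz
import random

def write_split(docs, source_name, out_dir="./cache"):
    random.seed(0)               # fixed seed so the split is reproducible
    random.shuffle(docs)
    cut = int(len(docs) * 0.75)  # ~75% train, ~25% validation
    for split, subset in [("train", docs[:cut]), ("validation", docs[cut:])]:
        with xz.open(f"{out_dir}/{split}.{source_name}.jsonl.xz", "wt", encoding="utf-8") as f:
            for doc in subset:
                f.write(json.dumps(doc) + "\n")

example_doc = {
    "url": "https://example.gov/some-document",  # source URL (illustrative)
    "text": "Full document text goes here.",
    "created_timestamp": "01-31-2013",            # %m-%d-%Y, or "" when unknown
    "downloaded_timestamp": datetime.date.today().strftime("%m-%d-%Y"),
}
```
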
25 | ## Privacy Experiments
26 |
27 | The privacy experiments live under the privacy folder, with a separate README explaining each experiment.
28 |
29 | ## Toxicity Experiments
30 |
31 | The toxicity experiments live under the toxicity subfolder.
32 |
33 | ## Pretraining and fine-tuning
34 |
35 | The pretraining processing scripts and fine-tuning scripts live under the pretraining folder.
36 |
37 | ## Data scrubbing
38 |
39 | We also examined all datasets for SSNs. While we log all potentially private information, we found that most detection APIs produced a significant number of false positives, so we looked narrowly for SSNs and did not encounter any. It is possible that the filters we use are not robust, but the datasets we use should already be pre-filtered for such sensitive information. Scripts for this can be found in the scrub folder; a simplified sketch of this kind of check is shown below.
40 |
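For illustration, a minimal sketch of the kind of narrow SSN check described above; the regex and function name are assumptions for this example, not the exact filter implemented in scrub/pii.py.

```python
# Illustrative only: a narrow SSN-style scan over one JSONL shard.
# The regex and function name are assumptions, not the exact filter in scrub/pii.py.
import json
import re

SSN_RE = re.compile(r"\b(?!000|666|9\d{2})\d{3}-(?!00)\d{2}-(?!0000)\d{4}\b")

def find_ssn_candidates(jsonl_path):
    """Yield (line_number, candidate) pairs for logging and manual review."""
    with open(jsonl_path, encoding="utf-8") as f:
        for i, line in enumerate(f, start=1):
            doc = json.loads(line)
            for match in SSN_RE.finditer(doc.get("text", "")):
                yield i, match.group(0)
```
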
41 |
42 | ## Citations
43 |
44 | Please cite our work if you use any of the tools here and/or the data.
45 |
46 | ```
47 | @article{hendersonkrass2022pileoflaw,
48 | title={Pile of Law: Learning Responsible Data Filtering from the Law and a 256GB Open-Source Legal Dataset},
49 | author={Henderson*, Peter and Krass*, Mark and Zheng, Lucia and Guha, Neel and Manning, Christopher and Jurafsky, Dan and Ho, Daniel E},
50 | year={2022}
51 | }
52 | ```
53 |
54 | Some of the datasets in this work are transformed from prior work. Please cite these works as well if you use this dataset:
55 |
56 | ```
57 | @inproceedings{borchmann-etal-2020-contract,
58 | title = "Contract Discovery: Dataset and a Few-Shot Semantic Retrieval Challenge with Competitive Baselines",
59 | author = "Borchmann, {\L}ukasz and
60 | Wisniewski, Dawid and
61 | Gretkowski, Andrzej and
62 | Kosmala, Izabela and
63 | Jurkiewicz, Dawid and
64 | Sza{\l}kiewicz, {\L}ukasz and
65 | Pa{\l}ka, Gabriela and
66 | Kaczmarek, Karol and
67 | Kaliska, Agnieszka and
68 | Grali{\'n}ski, Filip",
69 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020",
70 | month = nov,
71 | year = "2020",
72 | address = "Online",
73 | publisher = "Association for Computational Linguistics",
74 | url = "https://www.aclweb.org/anthology/2020.findings-emnlp.380",
75 | pages = "4254--4268",
76 | abstract = "We propose a new shared task of semantic retrieval from legal texts, in which a so-called contract discovery is to be performed {--} where legal clauses are extracted from documents, given a few examples of similar clauses from other legal acts. The task differs substantially from conventional NLI and shared tasks on legal information extraction (e.g., one has to identify text span instead of a single document, page, or paragraph). The specification of the proposed task is followed by an evaluation of multiple solutions within the unified framework proposed for this branch of methods. It is shown that state-of-the-art pretrained encoders fail to provide satisfactory results on the task proposed. In contrast, Language Model-based solutions perform better, especially when unsupervised fine-tuning is applied. Besides the ablation studies, we addressed questions regarding detection accuracy for relevant text fragments depending on the number of examples available. In addition to the dataset and reference results, LMs specialized in the legal domain were made publicly available.",
77 | }
78 |
79 | @data{T1/N1X6I4_2020,
80 | author = {Blair-Stanek, Andrew and Holzenberger, Nils and Van Durme, Benjamin},
81 | publisher = {Johns Hopkins University Data Archive},
82 | title = "{Tax Law NLP Resources}",
83 | year = {2020},
84 | version = {V2},
85 | doi = {10.7281/T1/N1X6I4},
86 | url = {https://doi.org/10.7281/T1/N1X6I4}
87 | }
88 |
89 | @article{hendrycks2021cuad,
90 | title={CUAD: An Expert-Annotated NLP Dataset for Legal Contract Review},
91 | author={Dan Hendrycks and Collin Burns and Anya Chen and Spencer Ball},
92 | journal={arXiv preprint arXiv:2103.06268},
93 | year={2021}
94 | }
95 |
96 | @inproceedings{koehn2005europarl,
97 | title={Europarl: A parallel corpus for statistical machine translation},
98 | author={Koehn, Philipp and others},
99 | booktitle={MT summit},
100 | volume={5},
101 | pages={79--86},
102 | year={2005},
103 | organization={Citeseer}
104 | }
105 |
106 | @article{DBLP:journals/corr/abs-1805-01217,
107 | author = {Marco Lippi and
108 | Przemyslaw Palka and
109 | Giuseppe Contissa and
110 | Francesca Lagioia and
111 | Hans{-}Wolfgang Micklitz and
112 | Giovanni Sartor and
113 | Paolo Torroni},
114 | title = {{CLAUDETTE:} an Automated Detector of Potentially Unfair Clauses in
115 | Online Terms of Service},
116 | journal = {CoRR},
117 | volume = {abs/1805.01217},
118 | year = {2018},
119 | url = {http://arxiv.org/abs/1805.01217},
120 | archivePrefix = {arXiv},
121 | eprint = {1805.01217},
122 | timestamp = {Mon, 13 Aug 2018 16:49:16 +0200},
123 | biburl = {https://dblp.org/rec/bib/journals/corr/abs-1805-01217},
124 | bibsource = {dblp computer science bibliography, https://dblp.org}
125 | }
126 |
127 | @article{ruggeri2021detecting,
128 | title={Detecting and explaining unfairness in consumer contracts through memory networks},
129 | author={Ruggeri, Federico and Lagioia, Francesca and Lippi, Marco and Torroni, Paolo},
130 | journal={Artificial Intelligence and Law},
131 | pages={1--34},
132 | year={2021},
133 | publisher={Springer}
134 | }
135 |
136 | @inproceedings{10.1145/3462757.3466066,
137 | author = {Huang, Zihan and Low, Charles and Teng, Mengqiu and Zhang, Hongyi and Ho, Daniel E. and Krass, Mark S. and Grabmair, Matthias},
138 | title = {Context-Aware Legal Citation Recommendation Using Deep Learning},
139 | year = {2021},
140 | isbn = {9781450385268},
141 | publisher = {Association for Computing Machinery},
142 | address = {New York, NY, USA},
143 | url = {https://doi.org/10.1145/3462757.3466066},
144 | doi = {10.1145/3462757.3466066},
145 | abstract = {Lawyers and judges spend a large amount of time researching the proper legal authority
146 | to cite while drafting decisions. In this paper, we develop a citation recommendation
147 | tool that can help improve efficiency in the process of opinion drafting. We train
148 | four types of machine learning models, including a citation-list based method (collaborative
149 | filtering) and three context-based methods (text similarity, BiLSTM and RoBERTa classifiers).
150 | Our experiments show that leveraging local textual context improves recommendation,
151 | and that deep neural models achieve decent performance. We show that non-deep text-based
152 | methods benefit from access to structured case metadata, but deep models only benefit
153 | from such access when predicting from context of insufficient length. We also find
154 | that, even after extensive training, RoBERTa does not outperform a recurrent neural
155 | model, despite its benefits of pretraining. Our behavior analysis of the RoBERTa model
156 | further shows that predictive performance is stable across time and citation classes.},
157 | booktitle = {Proceedings of the Eighteenth International Conference on Artificial Intelligence and Law},
158 | pages = {79–88},
159 | numpages = {10},
160 | keywords = {neural natural language processing, legal opinion drafting, citation recommendation, legal text, citation normalization},
161 | location = {S\~{a}o Paulo, Brazil},
162 | series = {ICAIL '21}
163 | }
164 | ```
165 |
--------------------------------------------------------------------------------
/dataset_creation/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Breakend/PileOfLaw/e6f7b11ba52e0fbbd392ad00bf3cc1632ae710b6/dataset_creation/.gitignore
--------------------------------------------------------------------------------
/dataset_creation/bar_outlines/scrape_bar_exam_outlines.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import numpy as np
4 | import json
5 | from urllib.parse import urljoin
6 | from bs4 import BeautifulSoup
7 | import textract
8 | import random
9 | import datetime
10 |
11 | try:
12 | import lzma as xz
13 | except ImportError:
14 | import pylzma as xz
15 |
16 | # Hacky, but this is just for scraping, just uncomment the one you want.
17 | #url = "https://law.stanford.edu/office-of-student-affairs/bar-exam-information/"
18 | url = "https://adamshajnfeld.weebly.com/"
19 |
20 |
21 | headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.76 Safari/537.36'} # This is chrome, you can set whatever browser you like
22 | #If there is no such folder, the script will create one automatically
23 | folder_location = './cache/'
24 | if not os.path.exists(folder_location):
25 | os.mkdir(folder_location)
26 |
27 | def save_to_processed(train, val, source_name, out_path):
28 | if not os.path.exists(out_path):
29 | os.makedirs(out_path)
30 | tf = os.path.join(out_path, f"train.{source_name}.jsonl")
31 | with open(tf, mode='w', encoding='utf-8') as out_file:
32 | for line in train:
33 | out_file.write(json.dumps(line) + "\n")
34 | print(f"Written {len(train)} documents to {tf}")
35 |
36 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl")
37 | with open(vf, mode='w', encoding='utf-8') as out_file:
38 | for line in val:
39 | out_file.write(json.dumps(line) + "\n")
40 | print(f"Written {len(val)} documents to {vf}")
41 | # now compress with lib
42 | print("compressing files...")
43 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out:
44 | out.write(xz.compress(bytes(f.read())))
45 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out:
46 | out.write(xz.compress(bytes(f.read())))
47 | print("compressed")
48 | response = requests.get(url, headers=headers)
49 | soup = BeautifulSoup(response.text, "html.parser")
50 | docs = []
51 | for link in soup.select("a[href$='.doc']"):
52 |         # Name the downloaded .doc files using the last portion of each link, which is unique in this case
53 | filename = os.path.join(folder_location,link['href'].split('/')[-1])
54 | if not os.path.exists(filename):
55 | with open(filename, 'wb') as f:
56 | f.write(requests.get(urljoin(url,link['href']), headers=headers).content)
57 | text = textract.process(filename)
58 |
59 | docs.append({
60 | "url" : link['href'],
61 | "text" : str(text),
62 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
63 | "created_timestamp" : "08-30-2006" if "stanford" in url else "02-01-2013" # though we don't have an exact date for when the docs were created,
64 | # the author graduated in 2006 and I assume took the bar then and
65 | # created these for the bar so I'm assuming 2006
66 | # if it's the adam website, he state that he took the bar in feb 2013
67 | })
68 | for link in soup.select("a[href$='.docx']"):
69 |         # Name the downloaded .docx files using the last portion of each link, which is unique in this case
70 | filename = os.path.join(folder_location,link['href'].split('/')[-1])
71 | if not os.path.exists(filename):
72 | with open(filename, 'wb') as f:
73 | f.write(requests.get(urljoin(url,link['href']), headers=headers).content)
74 | text = textract.process(filename)
75 | docs.append({
76 | "url" : link['href'],
77 | "text" : str(text),
78 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
79 | "created_timestamp" : "08-30-2006" if "stanford" in url else "02-01-2013" # though we don't have an exact date for when the docs were created,
80 | # the author graduated in 2006 and I assume took the bar then and
81 | # created these for the bar so I'm assuming 2006
82 | # if it's the adam website, he state that he took the bar in feb 2013
83 | })
84 |
85 | random.seed(0) # important for shuffling
86 | rand_idx = list(range(len(docs)))
87 | random.shuffle(rand_idx)
88 |
89 | train_idx = rand_idx[:int(len(rand_idx)*0.75)]
90 | val_idx = rand_idx[int(len(rand_idx)*0.75):]
91 |
92 | train_docs = np.array(docs)[train_idx]
93 | val_docs = np.array(docs)[val_idx]
94 |
95 | if "stanford" in url:
96 | save_to_processed(train_docs, val_docs, "stanfordbarexamoutlines", "./cache/")
97 | elif "adam" in url:
98 | save_to_processed(train_docs, val_docs, "shajnfeldbarexamoutlines", "./cache/")
99 |
--------------------------------------------------------------------------------
/dataset_creation/bva/process_bva.py:
--------------------------------------------------------------------------------
1 | # Processes BVA opinions (2001-2018)
2 | # Raw txt files downloaded through private link shared by authors of this paper: https://arxiv.org/abs/2106.10776
3 | # In the future, the files will be available here: https://drive.google.com/drive/folders/12lAd8Os7VFeqbTKi4wcqJqODjHIn0-yQ?usp=sharing
4 |
5 | import os
6 | import re
7 | import glob
8 | import json
9 | import tarfile
10 | import datetime
11 | import random
12 | from tqdm import tqdm
13 | from dateutil import parser
14 |
15 |
16 | # SOURCE URL
17 | URL = "https://drive.google.com/drive/folders/12lAd8Os7VFeqbTKi4wcqJqODjHIn0-yQ?usp=sharing"
18 |
19 | # DATA DIRS
20 | IN_DIR = "../../data/bva/raw" # path to BVA opinion tar files by year
21 | OUT_DIR = "../../data/bva/processed" # path to write data out to
22 |
23 | # Regex for date match
24 | date_regex = r'(?:\s*Decision\s+Date:\s+)(\d{1,2}/\d{1,2}/\d{1,2})'
25 |
26 | def save_to_file(data, out_dir, fname):
27 | if not os.path.exists(out_dir):
28 | os.makedirs(out_dir)
29 | fpath = os.path.join(out_dir, fname)
30 | with open(fpath, 'w') as out_file:
31 | for x in data:
32 | out_file.write(json.dumps(x) + "\n")
33 | print(f"Written {len(data)} to {fpath}")
34 |
35 | def main():
36 | tar_files = glob.glob(os.path.join(IN_DIR, "*.tar.gz"))
37 |
38 | docs = []
39 | for tar_file in tqdm(tar_files):
40 | print("Processing tar file:", tar_file)
41 | tar = tarfile.open(tar_file, 'r:gz')
42 | for member in tar.getmembers():
43 | if ".txt" in member.name:
44 | f = tar.extractfile(member)
45 | content = f.read()
46 | # Original encoding is in latin1, we want to decode it as utf-8
47 | text = content.decode('latin1').encode('utf-8').decode('utf-8')
48 |
49 | # Extract creation date
50 | lines = text.splitlines()
51 | creation_date = ""
52 | match = None
53 | for line in lines:
54 | match = re.search(date_regex, line)
55 | # If matched date, break at current line and parse / reformat match
56 | if match:
57 | creation_date = parser.parse(match.group(1)).strftime("%m-%d-%Y")
58 |
59 | doc = {
60 | "url": URL,
61 | "created_timestamp": creation_date,
62 | "downloaded_timestamp": datetime.date.today().strftime("%m-%d-%Y"),
63 | "text": text
64 | }
65 | docs.append(doc)
66 |
67 | # Shuffle and split into train / validation
68 | random.seed(0)
69 | random.shuffle(docs)
70 | train = docs[:int(len(docs)*0.75)]
71 | validation = docs[int(len(docs)*0.75):]
72 |
73 | save_to_file(train, OUT_DIR, "train.bva.jsonl")
74 | save_to_file(validation, OUT_DIR, "validation.bva.jsonl")
75 |
76 | if __name__ == '__main__':
77 | main()
--------------------------------------------------------------------------------
/dataset_creation/canadian_decisions/post_process.py:
--------------------------------------------------------------------------------
1 |
2 | from bs4 import BeautifulSoup
3 | import requests_cache
4 | import os
5 | import time
6 | import os
7 | from tqdm import tqdm
8 | import numpy as np
9 | import json
10 | import random
11 | import re
12 | import pickle
13 | import datetime
14 | try:
15 | import lzma as xz
16 | except ImportError:
17 | import pylzma as xz
18 |
19 | overwrite = True
20 | open_type = 'w' if overwrite else 'a'
21 | train_f = xz.open("./train.canadian_decisions.xz", open_type)
22 | val_f = xz.open("./validation.canadian_decisions.xz", open_type)
23 |
24 | with open("canada_cases.pickle", "rb") as f:
25 | _pickled = pickle.load(f)
26 |
27 | for key, value in _pickled.items():
28 | if "The specific page has either moved or is no longer part" in value['text']:
29 | continue
30 |
31 |
32 | datapoint = {
33 | "text" : value['text'],
34 | "created_timestamp" : value['year'],
35 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
36 | "url" : "https://www.bccourts.ca/search_judgments.aspx?obd={}&court=1#SearchTitle" if value["jdx"] == "bc" else "https://www.ontariocourts.ca/coa/decisions_main"
37 | }
38 |
39 | if random.random() > .75:
40 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
41 | else:
42 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
43 |
44 | train_f.close()
45 | val_f.close()
46 |
--------------------------------------------------------------------------------
/dataset_creation/collect_atticus_contracts.py:
--------------------------------------------------------------------------------
1 | # Code for formatting unlabelled Atticus contract data from: https://github.com/TheAtticusProject/cuad
2 |
3 | import glob
4 | import os
5 | import random
6 | from tqdm import tqdm
7 | import datetime
8 | import json
9 |
10 | IN_DIR = "../../data/atticus_contract/contracts/"
11 | OUT_DIR = "../../data/atticus_contract/processed"
12 |
13 |
14 | def save_to_file(data, fpath):
15 | with open(fpath, "w") as out_file:
16 | for x in data:
17 | out_file.write(json.dumps(x) + "\n")
18 | print(f"Written {len(data)} to {fpath}")
19 |
20 | def main():
21 |
22 | # load contracts
23 | files = glob.glob(os.path.join(IN_DIR, "*", "*.txt"))
24 | print(f"Collected {len(files)} contracts")
25 | docs = []
26 | for f in files:
27 | text = ""
28 | with open(f) as in_file:
29 | for line in in_file:
30 | text = text + line
31 |
32 | doc = {
33 | "url": "https://github.com/TheAtticusProject/cuad",
34 | "created_timestamp": "",
35 |             "downloaded_timestamp": datetime.date.today().strftime("%m-%d-%Y"),  # field name kept consistent with the other scrapers
36 | "text": text
37 | }
38 | docs.append(doc)
39 |
40 | # shuffle and split into train / validation
41 | random.seed(0)
42 | random.shuffle(docs)
43 | train = docs[:int(len(docs)*0.75)]
44 | validation = docs[int(len(docs)*0.75):]
45 |
46 | save_to_file(train, os.path.join(OUT_DIR, "train.atticus_contracts.jsonl"))
47 | save_to_file(validation, os.path.join(OUT_DIR, "validation.atticus_contracts.jsonl"))
48 |
49 |
50 |
51 | if __name__ == "__main__":
52 | main()
--------------------------------------------------------------------------------
/dataset_creation/congressional_hearings/scrape_congressional_hearings.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | import scrapelib
4 | import pytz
5 | import os, time  # "time" is needed for the retry sleep in _pages
6 | import cachetools
7 | import textract
8 | import numpy as np
9 | import random
10 | import json
11 | import requests_cache
12 | try:
13 | import lzma as xz
14 | except ImportError:
15 | import pylzma as xz
16 |
17 | requests = requests_cache.CachedSession('casebriefscache')
18 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'}
19 | headers['X-Api-Key'] = os.environ['API_KEY']
20 |
21 | class GovInfo(scrapelib.Scraper):
22 | BASE_URL = 'https://api.govinfo.gov'
23 |
24 | def __init__(self, *args, **kwargs):
25 | super().__init__(*args, **kwargs)
26 | print(os.environ['API_KEY'])
27 | self.headers['X-Api-Key'] = os.environ['API_KEY']
28 |
29 |
30 | def collections(self):
31 | endpoint = '/collections'
32 | response = requests.get(self.BASE_URL + endpoint, headers=headers)
33 | return response.json()
34 |
35 |
36 | def _format_time(self, dt):
37 |
38 | utc_time = dt.astimezone(pytz.utc)
39 |         time_str = utc_time.strftime('%Y-%m-%dT%H:%M:%SZ')  # format the UTC-converted time rather than the original dt
40 |
41 | return time_str
42 |
43 | def congressional_hearings(self, congress=117):
44 |
45 |
46 | partial = f"/collections/CHRG/{self._format_time(datetime.datetime(1970, 1, 1, 0, 0, tzinfo=pytz.utc))}/"
47 | url_template = self.BASE_URL + partial + "{end_time}"
48 | end_time = datetime.datetime.now(pytz.utc)
49 | end_time_str = self._format_time(end_time)
50 |
51 | seen = cachetools.LRUCache(30)
52 | for page in self._pages(url_template, congress, end_time_str):
53 | for package in page['packages']:
54 | package_id = package['packageId']
55 |
56 | if package_id in seen:
57 | continue
58 | else:
59 | # the LRUCache is like a dict, but all we care
60 | # about is whether we've seen this package
61 | # recently, so we just store None as the value
62 | # associated with the package_id key
63 | seen[package_id] = None
64 |
65 | response = requests.get(package['packageLink'], headers=headers)
66 | try:
67 | data = response.json()
68 | except:
69 | data = None
70 |
71 | yield data
72 |
73 | def _download_pdf(self, data):
74 | # if not "html" in data["download"]:
75 | # print(data["download"].keys())
76 | # return
77 | if data is None or data["download"] is None:
78 | return None
79 | url = data["download"]["zipLink"]
80 | tag = url.split("/")[-2]
81 | url = f"https://www.govinfo.gov/content/pkg/{tag}/html/{tag}.htm"
82 | response = requests.get(url, headers=headers)
83 |
84 | if response.status_code != 200:
85 | return None
86 |
87 | text = str(response.content)
88 |
89 | # with open(f'cache/{tag}.pdf', 'wb') as f:
90 | # f.write(response.content)
91 |
92 | # text = str(textract.process(f'cache/{tag}.pdf', method="tesseract"))
93 |
94 | datapoint = {
95 | "text" : text,
96 | "created_timestamp" : data['dateIssued'],
97 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
98 | "url" : url
99 | }
100 | return datapoint
101 |
102 | def _pages(self, url_template, congress, end_time):
103 | page_size = 100
104 |
105 | params = {'offset': 0,
106 | 'pageSize': page_size,
107 | "congress" : congress}
108 |
109 | #url = url_template
110 | url = url_template.format(end_time=end_time)
111 | try_count = 0
112 | responded = False
113 | while not responded:
114 | response = requests.get(url, params=params, headers=headers)
115 | if response.status_code != 200:
116 | responded = False
117 | time.sleep(random.random() * 5)
118 | try_count += 1
119 | if try_count >= 3:
120 | responded = True
121 | else:
122 | responded = True
123 | data = response.json()
124 |
125 | yield data
126 |
127 | while len(data['packages']) == page_size:
128 |
129 | # the API results are sorted in descending order by timestamp
130 | # so we can paginate through results by making the end_time
131 | # filter earlier and earlier
132 | earliest_timestamp = data['packages'][-1]['lastModified']
133 | url = url_template.format(end_time=earliest_timestamp)
134 |
135 | response = requests.get(url, params=params, headers=headers)
136 | data = response.json()
137 |
138 | yield data
139 |
140 |
141 | scraper = GovInfo()
142 |
143 | # Prints out all the different types of collections available
144 | # in the govinfo API
145 | print(scraper.collections())
146 |
147 | # Iterate through every congressional hearing
148 | #
149 | # For congressional hearings you need a specify a start
150 | # date time with a timezone
151 | # start_time = datetime.datetime(2020, 1, 1, 0, 0, tzinfo=pytz.utc)
152 | congresses = np.arange(89, 118)[::-1]
153 |
154 | #seen_urls = []
155 | #if os.path.exists("./cache/train.congressional_hearings.xz"):
156 | # with xz.open("./cache/train.congressional_hearings.xz", 'r') as f:
157 | # import pdb; pdb.set_trace()
158 | # seen_urls.extend([x["url"] for x in f.readlines()])
159 | # with xz.open("./cache/validation.congressional_hearings.xz", 'r') as f:
160 | # seen_urls.extend([x["url"] for x in f.readlines()])
161 |
162 | # val_f = xz.open("./cache/validation.congressional_hearings.xz", 'r')
163 | collected_urls = []
164 | i = 0
165 | overwrite = True
166 | open_type = 'w' if overwrite else 'a'
167 | train_f = xz.open("./cache/train.congressional_hearings.xz", open_type)
168 | val_f = xz.open("./cache/validation.congressional_hearings.xz", open_type)
169 | for congress in congresses:
170 | print(f"NOW GETTING CONGRESS {congress}")
171 | for hearing in scraper.congressional_hearings(congress):
172 | datapoint = scraper._download_pdf(hearing)
173 | if datapoint is None:
174 | print("No data for hearing")
175 | print(hearing)
176 | continue
177 | #print(datapoint["url"])
178 | if datapoint["url"] in collected_urls:
179 | print("ALREADY SAW URL!!")
180 | collected_urls.append(datapoint["url"])
181 | i += 1
182 | if i % 1000 == 0:
183 | print(i)
184 | # import pdb; pdb.set_trace()
185 | if random.random() > .75:
186 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
187 | else:
188 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
189 |
--------------------------------------------------------------------------------
/dataset_creation/constitutions/scrape_constitutions.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | url= "https://www.constituteproject.org/service/constitutions?in_force=true"
4 | url_template = "https://www.constituteproject.org/constitution/{const_id}?lang=en"
5 |
6 | from bs4 import BeautifulSoup
7 | from bs4 import BeautifulSoup
8 | import requests_cache
9 | import os
10 | import time
11 | import os
12 | import textract
13 | import numpy as np
14 | import json
15 | import random
16 | import datetime
17 | requests = requests_cache.CachedSession('scotus')
18 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'}
19 | try:
20 | import lzma as xz
21 | except ImportError:
22 | import pylzma as xz
23 |
24 | cache = "./cache"
25 |
26 | if not os.path.exists(cache):
27 | os.mkdir(cache)
28 |
29 |
30 | response = requests.get(url, headers=headers)
31 | ids = [r["id"] for r in response.json()]
32 | overwrite = True
33 | open_type = 'w' if overwrite else 'a'
34 | train_f = xz.open("./cache/train.constitutions.xz", open_type)
35 | val_f = xz.open("./cache/validation.constitutions.xz", open_type)
36 |
37 | for _id in ids:
38 | print(_id)
39 | url = url_template.format(const_id= _id)
40 | response = requests.get(url, headers=headers)
41 | if not response.from_cache:
42 | time.sleep(random.random()*2)
43 |     soup = BeautifulSoup(response.content, "html.parser")  # explicit parser avoids the bs4 default-parser warning
44 | constitution_text_sections = soup.find_all("div", {"class" : "constitution-content__copy"})
45 | title = soup.find("h1", {"class" : "clearfix"}).get_text()
46 | text = title
47 | for section in constitution_text_sections:
48 | text += section.get_text()
49 | text += "\n"
50 |
51 | datapoint = {
52 | "text" : text,
53 | "created_timestamp" : "",
54 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
55 | "url" : url
56 | }
57 |
58 | # This dataset is already heavily US biased, so it would be really weird not to train on the US
59 | # constitution
60 | if random.random() > .75 and "United_States" not in _id:
61 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
62 | else:
63 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
64 |
65 |
--------------------------------------------------------------------------------
/dataset_creation/courtlistener/scrape_and_process_cl_data.py:
--------------------------------------------------------------------------------
1 | import wget
2 | import json
3 | import os
4 | import tarfile
5 | from tqdm import tqdm
6 | from pathlib import Path
7 | import io
8 | import json
9 | import tempfile
10 | import datetime
11 | import bs4
12 | import random
13 | import shutil
14 |
15 |
16 | try:
17 | import lzma as xz
18 | except ImportError:
19 | import pylzma as xz
20 | url = "https://www.courtlistener.com/api/bulk-data/opinions/all.tar"
21 |
22 | if not os.path.exists("./cache/all.tar"):
23 | import wget
24 | filename = wget.download(url, out="./cache/")
25 |
26 |
27 |
28 | def idempotent(x):
29 | return x
30 |
31 |
32 | def html2text(x):
33 | soup = bs4.BeautifulSoup(x, "lxml")
34 | return soup.get_text()
35 |
36 | field_order = [
37 | ("plain_text", idempotent),
38 | ("html", html2text),
39 | ("html_lawbox", html2text),
40 | ("html_columbia", html2text),
41 | ("html_with_citations", html2text),
42 | ("xml_harvard", html2text)
43 | ]
44 |
45 | error_str = (
46 | "Unable to extract the content from this file. Please try reading the original."
47 | )
48 |
49 |
50 |
51 | def parse_json(item):
52 | """ From https://github.com/thoppe/The-Pile-FreeLaw/blob/master/P1_extract_text.py
53 | """
54 |
55 | js = json.loads(item)
56 |
57 | text = None
58 |
59 | if "html" in js and js["html"] == error_str:
60 | return None
61 |
62 | for k, func in field_order:
63 | if k in js and isinstance(js[k], str) and len(js[k]):
64 | text = func(js[k])
65 |
66 | if text is None:
67 | print(f"Skipping {item}, couldn't find text.")
68 | return None
69 |
70 | return {
71 | "url" : js['resource_uri'],
72 | "text" : text,
73 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
74 | "created_timestamp" : datetime.datetime.strptime(js['date_created'].split("T")[0], '%Y-%m-%d').strftime("%m-%d-%Y")
75 | }
76 |
77 | train = 0
78 | val = 0
79 |
80 | with xz.open("./cache/train.courtlisteneropinions.xz", 'w') as train_f:
81 | with xz.open("./cache/validation.courtlisteneropinions.xz", 'w') as val_f:
82 | with tarfile.open("./cache/all.tar") as all_tar:
83 | for jxd in all_tar.getmembers():
84 | with tarfile.open(fileobj=all_tar.extractfile(jxd)) as jxd_tar:
85 | for opinion in jxd_tar.getmembers():
86 | datapoint = parse_json(jxd_tar.extractfile(opinion).read())
87 |                         if datapoint is not None and random.random() > .75:
88 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
89 | val += 1
90 |                         elif datapoint is not None:  # parse_json returned None (no extractable text); skip it
91 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
92 | train += 1
93 |
94 | total, used, free = shutil.disk_usage("/")
95 |
96 | if train % 10000 == 0:
97 | print(f"Have {train} documents and {val} validation documents!")
98 | print("Total: %d GiB" % (total // (2**30)))
99 | print("Used: %d GiB" % (used // (2**30)))
100 | print("Free: %d GiB" % (free // (2**30)))
101 |
102 | if (free // (2**30)) < 30:
103 | print("RUNNING OUT OF DISK!!!")
104 | if (free // (2**30)) < 25:
105 | print("Too little disk!!!")
106 | break
107 |
108 | print(f"Have {train} documents and {val} validation documents!")
109 |
110 |
111 |
--------------------------------------------------------------------------------
/dataset_creation/courtlistener/scrape_and_process_cl_docket_data.py:
--------------------------------------------------------------------------------
1 | import wget
2 | import json
3 | import os
4 | import tarfile
5 | from tqdm import tqdm
6 | from pathlib import Path
7 | import io
8 | import json
9 | import tempfile
10 | import datetime
11 | import bs4
12 | import random
13 | import shutil
14 |
15 |
16 | try:
17 | import lzma as xz
18 | except ImportError:
19 | import pylzma as xz
20 | # url = "https://www.courtlistener.com/api/bulk-data/opinions/all.tar"
21 |
22 | # if not os.path.exists("./cache/all.tar"):
23 | # import wget
24 | # filename = wget.download(url, out="./cache/")
25 |
26 |
27 |
28 | def idempotent(x):
29 | return x
30 |
31 |
32 | def html2text(x):
33 | soup = bs4.BeautifulSoup(x, "lxml")
34 | return soup.get_text()
35 |
36 | field_order = [
37 | ("plain_text", idempotent),
38 | ("html", html2text),
39 | ("html_lawbox", html2text),
40 | ("html_columbia", html2text),
41 | ("html_with_citations", html2text),
42 | ("xml_harvard", html2text)
43 | ]
44 |
45 | error_str = (
46 | "Unable to extract the content from this file. Please try reading the original."
47 | )
48 |
49 | import requests
50 | import time
51 | import random
52 | import datetime
53 | import json
54 | import os
55 |
56 | def requestJSON(url):
57 | while True:
58 | try:
59 | r = requests.get(url, headers={'Authorization': f'Token {os.environ["API_KEY"]}' })
60 | if r.status_code != 200:
61 | print('error code', r.status_code)
62 | time.sleep(5)
63 | continue
64 | else:
65 | break
66 | except Exception as e:
67 | print(e)
68 | time.sleep(5)
69 | continue
70 | return r.json()
71 |
72 |
73 | next_page = "https://www.courtlistener.com/api/rest/v3/recap-documents/?is_available=true"
74 |
75 | val=0
76 | train=0
77 | from dateutil.relativedelta import *
78 | import datetime
79 | import datefinder
80 |
81 | if os.path.exists("./cache/cur_url.txt"):
82 | with open("./cache/cur_url.txt", "r") as f:
83 | next_page = f.read().strip()
84 | dates = list(datefinder.find_dates(next_page))
85 | cur_month = dates[0]
86 | prev_month = dates[1]
87 | if cur_month < prev_month:
88 | tmp = prev_month
89 | prev_month = cur_month
90 | cur_month = tmp
91 | else:
92 | cur_month = datetime.datetime.now()
93 | prev_month = cur_month - relativedelta(days=3)
94 | next_page = f"https://www.courtlistener.com/api/rest/v3/docket-entries/?date_filed__lt={cur_month.strftime('%Y-%m-%d')}&date_filed__gt={prev_month.strftime('%Y-%m-%d')}&fields=date_filed%2Crecap_documents%2Cdescription&recap_documents__is_available=true"
95 |
96 |
97 | with xz.open("./cache/train.courtlistenerdocketentries.xz", 'a') as train_f:
98 | with xz.open("./cache/validation.courtlistenerdocketentries.xz", 'a') as val_f:
99 | while True:
100 | #print(cur_month.strftime('%Y-%m-%d'))
101 | if next_page is None:
102 | next_page = f"https://www.courtlistener.com/api/rest/v3/docket-entries/?date_filed__lt={cur_month.strftime('%Y-%m-%d')}&date_filed__gt={prev_month.strftime('%Y-%m-%d')}&fields=date_filed%2Crecap_documents%2Cdescription&recap_documents__is_available=true"
103 | while next_page is not None:
104 | print(next_page)
105 | js_data = requestJSON(next_page)
106 | if 'count' in js_data:
107 | print(js_data['count'])
108 | time.sleep(random.random()*3)
109 | next_page = js_data["next"]
110 | if next_page is not None:
111 | with open('./cache/cur_url.txt', 'w') as f:
112 | f.write(next_page)
113 | for docket_entry in js_data["results"]:
114 | for recap_data in docket_entry["recap_documents"]:
115 | if "plain_text" in recap_data and recap_data["plain_text"]:
116 | datapoint = {
117 | "url" : recap_data['resource_uri'],
118 | "text" : recap_data["plain_text"],
119 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
120 | "created_timestamp" : docket_entry['date_filed']
121 | }
122 | if random.random() > .75:
123 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
124 | val += 1
125 | else:
126 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
127 | train += 1
128 | if train % 5000 == 0:
129 | print(f"Have {train} documents and {val} validation documents!")
130 | cur_month = prev_month
131 | prev_month = cur_month - relativedelta(days=3)
132 |
133 |
134 |
--------------------------------------------------------------------------------
/dataset_creation/creative_commons_casebooks/scrape_cc_casebooks.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 | import os
3 | import datetime
4 | import textract
5 | import random
6 | import json
7 | try:
8 | import lzma as xz
9 | except ImportError:
10 | import pylzma as xz
11 | overwrite = True
12 | open_type = 'w' if overwrite else 'a'
13 | train_f = xz.open("./cache/train.cc_casebooks.xz", open_type)
14 | val_f = xz.open("./cache/validation.cc_casebooks.xz", open_type)
15 | # ASSUMES THAT YOUVE DOWNLOADED THE TEXTBOOKS INTO THE CACHE AND THAT THEY'RE ALL CC LICENSE
16 | docs = []
17 | for path, subdirs, files in os.walk("./cache/"):
18 | for name in files:
19 | if not name.endswith(".pdf"):
20 | continue
21 | text = str(textract.process(os.path.join(path, name)))
22 | datapoint = {
23 | "text" : text,
24 | "created_timestamp" : "",
25 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
26 | "url" : "https://open.umn.edu/opentextbooks/"
27 | }
28 |
29 | if random.random() > .75:
30 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
31 | else:
32 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
--------------------------------------------------------------------------------
/dataset_creation/creditcardcfpb/scrape_cfbp_cc_agreements.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 | import os
3 | import datetime
4 | import textract
5 | import random
6 | import json
7 | try:
8 | import lzma as xz
9 | except ImportError:
10 | import pylzma as xz
11 | url = "https://files.consumerfinance.gov/a/assets/Credit_Card_Agreements_2021_Q2.zip"
12 |
13 | cache_path = "./cache/Credit_Card_Agreements_2021_Q2.zip"
14 | if not os.path.exists(cache_path):
15 | import wget
16 | filename = wget.download(url, out="./cache/")
17 |
18 | with zipfile.ZipFile(cache_path, 'r') as zip_ref:
19 | zip_ref.extractall("./cache/")
20 |
21 | overwrite = True
22 | open_type = 'w' if overwrite else 'a'
23 | train_f = xz.open("./cache/train.cfpb_cc.xz", open_type)
24 | val_f = xz.open("./cache/validation.cfpb_cc.xz", open_type)
25 |
26 | docs = []
27 | for path, subdirs, files in os.walk("./cache/"):
28 | for name in files:
29 | if not name.endswith(".pdf"):
30 | continue
31 | text = str(textract.process(os.path.join(path, name)))
32 | datapoint = {
33 | "text" : text,
34 | "created_timestamp" : "2021",
35 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
36 | "url" : url
37 | }
38 |
39 | if random.random() > .75:
40 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
41 | else:
42 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
--------------------------------------------------------------------------------
/dataset_creation/dol_ecab/scrape_dol_ecab.py:
--------------------------------------------------------------------------------
1 | from bs4 import BeautifulSoup
2 | import requests_cache
3 | import os
4 | import time
5 | import os
6 | import textract
7 | import numpy as np
8 | import json
9 | import random
10 | import datetime
11 | import re
12 | import pandas as pd
13 |
14 | try:
15 | import lzma as xz
16 | except ImportError:
17 | import pylzma as xz
18 |
19 | cache = "./cache"
20 |
21 | requests = requests_cache.CachedSession('scotus')
22 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'}
23 |
24 | if not os.path.exists(cache):
25 | os.mkdir(cache)
26 |
27 | docs = []
28 | base_url = "https://www.dol.gov/agencies/ecab/decisions/{year}/{month}"
29 | pdf_url = "https://www.dol.gov/sites/dolgov/files/ecab/decisions/{year}/{month}/{tag}.pdf"
30 |
31 | docs_for_pseudonyms = []
32 |
33 | num_pseudonyms = 0
34 |
35 | for vol in range(2007,2023):
36 | for month in ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]:
37 | url = base_url.format(year=vol, month=month)
38 | # page2 = requests.get(url, headers=headers)
39 | # soup = BeautifulSoup(page2.text,"lxml")
40 | converters = {c: lambda x: str(x) for c in range(3)}
41 | try:
42 | table = pd.read_html(url, converters=converters)
43 | except:
44 | print(f"Skipping {month} {vol} b/c table formatting issue.")
45 |             continue  # skip this month; otherwise a stale (or undefined) table would be used below
46 | for row in table[0].iterrows():
47 | tag, date, casename = row[1].values
48 |
49 | link = pdf_url.format(year=vol, month=month, tag=tag)
50 |
51 | print(tag)
52 | if not os.path.exists(f'{cache}/{tag}.pdf'):
53 | print(link)
54 | pdf = requests.get(link, headers=headers)
55 |
56 | with open(f'{cache}/{tag}.pdf', 'wb') as f:
57 | f.write(pdf.content)
58 |
59 | try:
60 | text = textract.process(f'{cache}/{tag}.pdf', encoding='utf-8')
61 | text = text.decode("utf8")
62 | except:
63 | print(f"Skipping {tag}!")
64 | continue
65 |
66 | try:
67 | datetime_object = datetime.datetime.strptime(date, '%B %d, %Y')
68 | except:
69 | try:
70 | datetime_object = datetime.datetime.strptime(date, '%B%d, %Y')
71 | except:
72 | date = " ".join(date.split(" ")[:-1])
73 | try:
74 | datetime_object = datetime.datetime.strptime(date, '%B %d, %Y')
75 | except:
76 | continue
77 | timestamp = datetime_object.strftime("%m-%d-%Y")
78 |
79 | if len(text) < 100:
80 | import pdb; pdb.set_trace()
81 |
82 | docs.append({
83 | "text" : text,
84 | "created_timestamp" : timestamp,
85 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
86 | "url" : link
87 | })
88 |
89 |
90 | def save_to_processed(train, val, source_name, out_path):
91 | if not os.path.exists(out_path):
92 | os.makedirs(out_path)
93 | tf = os.path.join(out_path, f"train.{source_name}.jsonl")
94 | with open(tf, mode='w', encoding='utf-8') as out_file:
95 | for line in train:
96 | out_file.write(json.dumps(line) + "\n")
97 | print(f"Written {len(train)} documents to {tf}")
98 |
99 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl")
100 | with open(vf, mode='w', encoding='utf-8') as out_file:
101 | for line in val:
102 | out_file.write(json.dumps(line) + "\n")
103 | print(f"Written {len(val)} documents to {vf}")
104 |
105 | # now compress with lib
106 | print("compressing files...")
107 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out:
108 | out.write(xz.compress(bytes(f.read())))
109 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out:
110 | out.write(xz.compress(bytes(f.read())))
111 | print("compressed")
112 |
113 | random.seed(0) # important for shuffling
114 | rand_idx = list(range(len(docs)))
115 | random.shuffle(rand_idx)
116 |
117 | train_idx = rand_idx[:int(len(rand_idx)*0.75)]
118 | val_idx = rand_idx[int(len(rand_idx)*0.75):]
119 |
120 | train_docs = np.array(docs)[train_idx]
121 | val_docs = np.array(docs)[val_idx]
122 |
123 | save_to_processed(train_docs, val_docs, "dol_ecab", "./cache/")
124 |
--------------------------------------------------------------------------------
/dataset_creation/echr/scrape_echr.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 | import os
3 | import datetime
4 | try:
5 | import lzma as xz
6 | except ImportError:
7 | import pylzma as xz
8 | url = "https://archive.org/download/ECHR-ACL2019/ECHR_Dataset.zip"
9 |
10 | cache_path = "./cache/ECHR_Dataset.zip"
11 | if not os.path.exists(cache_path):
12 | import wget
13 | filename = wget.download(url, out="./cache/")
14 |
15 | with zipfile.ZipFile(cache_path, 'r') as zip_ref:
16 | zip_ref.extractall("./cache/")
17 |
18 |
19 | overwrite = True
20 | open_type = 'w' if overwrite else 'a'
21 | train_f = xz.open("./cache/train.echr.xz", open_type)
22 | val_f = xz.open("./cache/validation.echr.xz", open_type)
23 | import os, json
24 |
25 | path_to_json = './cache/EN_train'
26 | json_files = [pos_json for pos_json in os.listdir(path_to_json) if pos_json.endswith('.json')]
27 |
28 | for json_file in json_files:
29 | with open('./cache/EN_train/' + json_file, "r") as f:
30 | loaded = json.loads(f.read())
31 |
32 | blocklist = ["ITEMID"]
33 | text = ""
34 | for key, val in loaded.items():
35 |         if key not in blocklist and val != "" and val is not None:  # apply the ITEMID blocklist
36 | if not isinstance(val, list):
37 | text += f"{key}: {val}\n"
38 | else:
39 | if len(val) > 0:
40 | joined = '\n'.join(val)
41 | text += f"{key}: {joined}\n"
42 | datapoint = {
43 | "url" : url,
44 | "text" : text,
45 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
46 | "created_timestamp" : loaded["DATE"]
47 | }
48 |
49 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
50 |
51 |
52 | path_to_json = 'cache/EN_dev'
53 | json_files = [pos_json for pos_json in os.listdir(path_to_json) if pos_json.endswith('.json')]
54 |
55 | for json_file in json_files:
56 | with open( 'cache/EN_dev/'+ json_file, "r") as f:
57 | loaded = json.loads(f.read())
58 |
59 | blocklist = ["ITEMID"]
60 | text = ""
61 | for key, val in loaded.items():
62 |         if key not in blocklist and val != "" and val is not None:
63 | if not isinstance(val, list):
64 | text += f"{key}: {val}\n"
65 | else:
66 | if len(val) > 0:
67 | joined= '\n'.join(val)
68 | text += f"{key}: {joined}\n"
69 | datapoint = {
70 | "url" : url,
71 | "text" : text,
72 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
73 | "created_timestamp" : loaded["DATE"]
74 | }
75 |
76 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
77 |
78 |
--------------------------------------------------------------------------------
/dataset_creation/edgar_contracts/process_edgar_contracts_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import random
4 | import numpy as np
5 | import json
6 |
7 |
8 | try:
9 | import lzma as xz
10 | except ImportError:
11 | import pylzma as xz
12 | url = "https://applica-public.s3-eu-west-1.amazonaws.com/contract-discovery/edgar.txt.xz"
13 |
14 | if not os.path.exists("./cache/edgar.txt.xz"):
15 | import wget
16 | filename = wget.download(url, out="./cache/")
17 |
18 | def save_to_processed(train, val, source_name, out_path):
19 | if not os.path.exists(out_path):
20 | os.makedirs(out_path)
21 | tf = os.path.join(out_path, f"train.{source_name}.jsonl")
22 | with open(tf, mode='w', encoding='utf-8') as out_file:
23 | for line in train:
24 | out_file.write(json.dumps(line) + "\n")
25 | print(f"Written {len(train)} documents to {tf}")
26 |
27 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl")
28 | with open(vf, mode='w', encoding='utf-8') as out_file:
29 | for line in val:
30 | out_file.write(json.dumps(line) + "\n")
31 | print(f"Written {len(val)} documents to {vf}")
32 | # now compress with lib
33 | print("compressing files...")
34 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out:
35 | out.write(xz.compress(bytes(f.read())))
36 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out:
37 | out.write(xz.compress(bytes(f.read())))
38 | print("compressed")
39 |
40 | docs = []
41 |
42 | with xz.open('./cache/edgar.txt.xz', mode='rt') as f:
43 | for line in f:
44 | docs.append({
45 | "url" : url,
46 | "text" : line,
47 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
48 | "created_timestamp" : "" # Unfortunately the dataset that originally scraped the material didn't keep the date.
49 | })
50 |
51 | random.seed(0) # important for shuffling
52 | rand_idx = list(range(len(docs)))
53 | random.shuffle(rand_idx)
54 |
55 | train_idx = rand_idx[:int(len(rand_idx)*0.75)]
56 | val_idx = rand_idx[int(len(rand_idx)*0.75):]
57 |
58 | train_docs = np.array(docs)[train_idx]
59 | val_docs = np.array(docs)[val_idx]
60 |
61 | save_to_processed(train_docs, val_docs, "edgar", "./cache/")
62 |
--------------------------------------------------------------------------------
/dataset_creation/eoir/scrape_eoir.py:
--------------------------------------------------------------------------------
1 | from bs4 import BeautifulSoup
2 | import requests_cache
3 | import os
4 | import time
5 | import os
6 | import textract
7 | import numpy as np
8 | import json
9 | import random
10 | import datetime
11 | import re
12 |
13 | try:
14 | import lzma as xz
15 | except ImportError:
16 | import pylzma as xz
17 |
18 | cache = "./cache"
19 |
20 | requests = requests_cache.CachedSession('scotus')
21 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'}
22 |
23 | if not os.path.exists(cache):
24 | os.mkdir(cache)
25 |
26 | docs = []
27 | base_url_27 = "https://www.justice.gov/eoir/volume-{vol:02}"
28 | base_url_26 = "https://www.justice.gov/eoir/precedent-decisions-volume-{vol:02}"
29 |
30 | docs_for_pseudonyms = []
31 |
32 | num_pseudonyms = 0
33 |
34 | for vol in range(8,29):
35 | if vol >= 27:
36 | url = base_url_27.format(vol=vol)
37 | else:
38 | url = base_url_26.format(vol=vol)
39 | page2 = requests.get(url, headers=headers)
40 | soup = BeautifulSoup(page2.text,"lxml")
41 | g_dates = soup.findAll("td", attrs={'class': None, "colspan" : None})
42 | g_data = soup.findAll("td", {"class" : "rteright"})
43 |
44 | for data, date in zip(g_data, g_dates):
45 | # data = data.find("a")
46 | if "justice.gov" in data.find("a")["href"]:
47 | link = data.find("a")["href"]
48 | else:
49 | link = "https://www.justice.gov/" + data.find("a")["href"]
50 | if link.endswith(".pdf"):
51 | tag = link.split("/")[-1].replace(".pdf", "")
52 | else:
53 | tag = link.split("/")[-2]
54 | tag += "__" + str(vol)
55 | print(tag)
56 | if not os.path.exists(f'{cache}/{tag}.pdf'):
57 | print(link)
58 | pdf = requests.get(link)
59 |
60 | with open(f'{cache}/{tag}.pdf', 'wb') as f:
61 | f.write(pdf.content)
62 |
63 | text = textract.process(f'{cache}/{tag}.pdf', encoding='utf-8')
64 | text = text.decode("utf8")
65 |
66 | try:
67 | name = date.find("strong").text.replace(",", "")
68 | except:
69 | try:
70 | name = date.find("b").text.replace(",", "")
71 | except:
72 | import pdb; pdb.set_trace()
73 |
74 | print(name)
75 |
76 | def check_pseudo(ns):
77 | ns = ns.replace("et al.", "").strip()
78 | ns = ns.replace("‑", "-")
79 | if ns == "DEF-" or ns == "D'O-" or ns == "DEN-" or ns == "DEG" or ns == "DE M-" or ns == "D-S- INC." or ns == "DIP-":
80 | return True
81 | for n in re.split('(&|and|AND)',ns):
82 | n = n.strip()
83 | n = n.replace(" ", "")
84 | created_pseudo = "-".join(n.replace("-", ""))
85 | if created_pseudo == n or created_pseudo + "-" == n:
86 | return True
87 | created_pseudo = ".".join(n.replace(".", ""))
88 | if created_pseudo == n:
89 | return True
90 | return False
91 |
92 | is_pseudonym = check_pseudo(name)
93 | print(is_pseudonym)
94 | # if not is_pseudonym:
95 | # import pdb; pdb.set_trace()
96 |
97 | if is_pseudonym:
98 | num_pseudonyms +=1
99 |
100 | issuance_date = date.text.split("(")[-1].split(" ")[-1]
101 | issuance_date = issuance_date.replace(")", "").strip()
102 | print(issuance_date)
103 | if issuance_date == "":
104 | import pdb; pdb.set_trace()
105 |
106 | if len(text) < 100:
107 | import pdb; pdb.set_trace()
108 |
109 | docs.append({
110 | "text" : text,
111 | "created_timestamp" : issuance_date,
112 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
113 | "url" : link
114 | })
115 |
116 | docs_for_pseudonyms.append({
117 | "text" : text,
118 | "created_timestamp" : issuance_date,
119 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
120 | "url" : link,
121 | "name" : name,
122 | "is_pseudonym" : is_pseudonym
123 | })
124 |
125 |
126 | def save_to_processed(train, val, source_name, out_path):
127 | if not os.path.exists(out_path):
128 | os.makedirs(out_path)
129 | tf = os.path.join(out_path, f"train.{source_name}.jsonl")
130 | with open(tf, mode='w', encoding='utf-8') as out_file:
131 | for line in train:
132 | out_file.write(json.dumps(line) + "\n")
133 | print(f"Written {len(train)} documents to {tf}")
134 |
135 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl")
136 | with open(vf, mode='w', encoding='utf-8') as out_file:
137 | for line in val:
138 | out_file.write(json.dumps(line) + "\n")
139 | print(f"Written {len(val)} documents to {vf}")
140 |
141 | # now compress with lib
142 | print("compressing files...")
143 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out:
144 | out.write(xz.compress(bytes(f.read())))
145 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out:
146 | out.write(xz.compress(bytes(f.read())))
147 | print("compressed")
148 |
149 | random.seed(0) # important for shuffling
150 | rand_idx = list(range(len(docs)))
151 | random.shuffle(rand_idx)
152 |
153 | train_idx = rand_idx[:int(len(rand_idx)*0.75)]
154 | val_idx = rand_idx[int(len(rand_idx)*0.75):]
155 |
156 | train_docs = np.array(docs)[train_idx]
157 | val_docs = np.array(docs)[val_idx]
158 |
159 | save_to_processed(train_docs, val_docs, "eoir", "./cache/")
160 |
161 | save_to_processed(docs_for_pseudonyms, [], "eoir_pseudonym", "./cache/")
162 |
163 | print("NUM PSEUDONYMS")
164 | print(num_pseudonyms)
165 |
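# A minimal usage sketch (not part of the original script) of the check_pseudo
# heuristic defined above, run on made-up captions: dash- or period-joined
# initials are treated as pseudonymized party names, ordinary names are not.
for example in ["J-G-", "E-R-A-L- et al.", "A.B.C", "RODRIGUEZ"]:
    print(example, check_pseudo(example))
# -> J-G- True, E-R-A-L- et al. True, A.B.C True, RODRIGUEZ False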
--------------------------------------------------------------------------------
/dataset_creation/euro_parl.py:
--------------------------------------------------------------------------------
1 | # Processes data from the European Parliament Proceedings Parallel Corpus 1996-2011 from
2 | # https://www.statmt.org/europarl/. We only use the English data. We strip out all tags.
3 |
4 | from bs4 import BeautifulSoup
5 | import os
6 | import glob
7 | import random
8 | from tqdm import tqdm
9 | import json
10 | import datetime
11 |
12 | URL = "https://www.statmt.org/europarl/"
13 | DATA_DIR = "../../data/europarl"
14 | OUT_DIR = "../../data/europarl/processed"
15 |
16 | def save_to_file(data, fpath):
17 | with open(fpath, "w") as out_file:
18 | for x in data:
19 | out_file.write(json.dumps(x) + "\n")
20 | print(f"Written {len(data)} to {fpath}")
21 |
22 | def process_file(f):
23 | doc = []
24 | with open(f, "r") as in_file:
25 | for line in in_file:
26 | line = line.strip()
27 | if "<" in line or len(line) == 0:
28 | continue
29 | doc.append(line)
30 | return "\n".join(doc)
31 |
32 |
33 | def main():
34 |
35 | # load data
36 | in_glob = os.path.join(DATA_DIR, "txt", "en", "*.txt")
37 | files = glob.glob(in_glob)
38 | print(f"Found {len(files)} files.")
39 |
40 | # parse files
41 | docs = []
42 | for f in tqdm(files):
43 | text = process_file(f)
44 | doc = {
45 | "url": URL,
46 | "created_timestamp" : "", # We don't atually know individual
47 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
48 | "text": text
49 | }
50 | docs.append(doc)
51 |
52 | # shuffle and split into train / validation
53 | random.seed(0)
54 | random.shuffle(docs)
55 | train = docs[:int(len(docs)*0.75)]
56 | validation = docs[int(len(docs)*0.75):]
57 |
58 | save_to_file(train, os.path.join(OUT_DIR, "train.euro_parl.jsonl"))
59 | save_to_file(validation, os.path.join(OUT_DIR, "validation.euro_parl.jsonl"))
60 |
61 |
62 |
63 |
64 | if __name__ == "__main__":
65 | main()
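# A minimal sketch (not part of the original script) of how process_file strips
# Europarl markup: any line containing "<" and any blank line is dropped.
import tempfile

sample = '<CHAPTER ID=1>\nResumption of the session\n\n<SPEAKER ID=1 NAME="President">\nI declare resumed the session.\n'
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write(sample)
print(process_file(tmp.name))
# -> Resumption of the session
#    I declare resumed the session.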
--------------------------------------------------------------------------------
/dataset_creation/eurolex/scrape_eurolex.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import random
3 | import json
4 | import datetime
5 | try:
6 | import lzma as xz
7 | except ImportError:
8 | import pylzma as xz
9 |
10 | lex = pd.read_csv("./cache/EurLex_all.csv")
11 |
12 | overwrite = True
13 | open_type = 'w' if overwrite else 'a'
14 | train_f = xz.open("./cache/train.eurlex.xz", open_type)
15 | val_f = xz.open("./cache/validation.eurlex.xz", open_type)
16 |
17 | for act_name, Act_type, subject_matter, date_publication, text in zip(lex["Act_name"], lex["Act_type"], lex["Subject_matter"], lex["Date_publication"], lex["act_raw_text"]):
18 |
19 | datapoint = {
20 | "text" : f"Name: {act_name}\n Type: {Act_type}\n Subject Matter: {subject_matter}\n Date Published: {date_publication}\n\n {text}",
21 | "created_timestamp" : date_publication,
22 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
23 | "url" : "https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/0EGYWY"
24 | }
25 | if random.random() > .75:
26 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
27 | else:
28 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
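# A minimal sketch (not part of the original script) showing how the compressed
# JSONL written above can be read back for inspection once the writers are closed.
import lzma, json

with lzma.open("./cache/train.eurlex.xz", "rt", encoding="utf-8") as f:
    for i, line in enumerate(f):
        record = json.loads(line)
        print(record["created_timestamp"], record["text"][:80])
        if i >= 2:  # peek at the first few records only
            break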
--------------------------------------------------------------------------------
/dataset_creation/federal_register/process_federal_register.py:
--------------------------------------------------------------------------------
1 | # Processes federal register proposed rules (2000 - 2021)
2 | # Bulk data pulled from: https://www.govinfo.gov/bulkdata/xml/FR
3 |
4 | import os
5 | import json
6 | import datetime
7 | import random
8 | from urllib.request import Request, urlopen
9 | # xpath only available in lxml etree, not ElementTree
10 | from lxml import etree
11 | from tqdm import tqdm
12 | from dateutil import parser
13 |
14 | # BASE URL (bulk data API endpoint)
15 | BASE_URL = "https://www.govinfo.gov/bulkdata/xml/FR"
16 |
17 | # DATA DIRS
18 | OUT_DIR = "../../data/federal_register/processed" # path to write data out to
19 |
20 | # Request variables
21 | headers = {'User-Agent': 'Mozilla/5.0', 'Accept': 'application/xml'}
22 |
23 |
24 | def save_to_file(data, out_dir, fname):
25 | if not os.path.exists(out_dir):
26 | os.makedirs(out_dir)
27 | fpath = os.path.join(out_dir, fname)
28 | with open(fpath, 'w') as out_file:
29 | for x in data:
30 | out_file.write(json.dumps(x) + "\n")
31 | print(f"Written {len(data)} to {fpath}")
32 |
33 |
34 | def request_raw_data():
35 | urls = [BASE_URL]
36 | xmls = []
37 | while len(urls) > 0:
38 | next_urls = []
39 | for url in urls:
40 | print(url)
41 | request = Request(url, headers=headers)
42 | with urlopen(request) as response:
43 | root = etree.fromstring(response.read())
44 | elems = root.xpath("*/file[folder='true' and name!='resources']")
45 | if len(elems) > 0:
46 | for e in elems:
47 | next_url = e.find("link").text
48 | next_urls.append(next_url)
49 | else:
50 | elems = root.xpath("*/file[mimeType='application/xml']")
51 | for e in elems:
52 | xml_url = e.find("link").text
53 | request = Request(xml_url, headers=headers)
54 | with urlopen(request) as response:
55 | xml = etree.fromstring(response.read())
56 | # Add tuple of xml_url, xml Element instance
57 | xmls.append((xml_url, xml))
58 | urls = next_urls
59 |
60 | return xmls
61 |
62 | def extract_rule_docs(xmls):
63 | docs = []
64 | for (xml_url, xml) in tqdm(xmls):
65 | print(xml_url)
66 | date = xml.find("DATE").text
67 | creation_date = ""
68 | try:
69 | creation_date = parser.parse(date).strftime("%m-%d-%Y")
70 | except:
71 | pass
72 |
73 | proposed_rules = xml.xpath("PRORULES/PRORULE")
74 | for rule in proposed_rules:
75 | # In Python 3, use encoding='unicode'
76 | # In Python 2, use encoding='utf-8' and decode
77 | all_text = etree.tostring(rule, encoding='unicode', method='text')
78 |
79 | doc = {
80 | "url": xml_url,
81 | "created_timestamp": creation_date,
82 | "downloaded_timestamp": datetime.date.today().strftime("%m-%d-%Y"),
83 | "text": all_text
84 | }
85 | docs.append(doc)
86 |
87 | return docs
88 |
89 |
90 | def main():
91 | # Request raw data directly using bulk data API
92 | xmls = request_raw_data()
93 | docs = extract_rule_docs(xmls)
94 |
95 | # Shuffle and split into train / validation
96 | random.seed(0)
97 | random.shuffle(docs)
98 | train = docs[:int(len(docs)*0.75)]
99 | validation = docs[int(len(docs)*0.75):]
100 |
101 | save_to_file(train, OUT_DIR, "train.federal_register.jsonl")
102 | save_to_file(validation, OUT_DIR, "validation.federal_register.jsonl")
103 |
104 |
105 | if __name__ == '__main__':
106 | main()
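# A self-contained sketch (synthetic listing, not real govinfo output) of the
# traversal in request_raw_data above: folder entries are followed, and leaf
# entries with an XML mimeType are fetched.
from lxml import etree

listing = etree.fromstring(b"""<files><fileList>
<file><folder>true</folder><name>2021</name><link>https://example.invalid/FR/2021</link></file>
<file><folder>true</folder><name>resources</name><link>https://example.invalid/FR/resources</link></file>
<file><mimeType>application/xml</mimeType><link>https://example.invalid/FR-2021-01-04.xml</link></file>
</fileList></files>""")

folders = listing.xpath("*/file[folder='true' and name!='resources']")
leaves = listing.xpath("*/file[mimeType='application/xml']")
print([e.find("link").text for e in folders])  # -> ['https://example.invalid/FR/2021']
print([e.find("link").text for e in leaves])   # -> ['https://example.invalid/FR-2021-01-04.xml']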
--------------------------------------------------------------------------------
/dataset_creation/federal_rules/frcp.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from bs4 import BeautifulSoup
3 | import datetime
4 | import random
5 | import numpy as np
6 | import lzma as xz
7 | import json
8 |
9 | def scrape_rule(url_val):
10 | page = requests.get(url_val)
11 | soup = BeautifulSoup(page.content, "html.parser")
12 | rule_content = soup.find(id="content1")
13 | if rule_content is None:
14 | return ""
15 | remove_text = rule_content.find(id="book-navigation-4940065")
16 | add_title = soup.find(id="page-title")
17 | # removing the bottom navigation titles that get scraped with the rule body
18 | stop_idx = rule_content.text.find(remove_text.text)
19 | # appending the name and edited rule content
20 | final_rule = add_title.text + rule_content.text[:stop_idx].strip()
21 | return final_rule
22 |
23 | # saving the train / val files
24 | def save_final_data(state_json_data, final_path):
25 | with xz.open(final_path, 'w') as state_data:
26 | for state_doc in state_json_data:
27 | state_data.write((json.dumps(state_doc) + "\n").encode("utf-8"))
28 |
29 | basic_url_val = "https://www.law.cornell.edu/rules/frcp/rule_"
30 | rule_idx = [str(x + 1) for x in list(range(86))]
31 | rule_idx.extend(["A", "B", "C", "D", "E", "F", "G"])
32 |
33 | # get the number of rules in an article
34 | docs = []
35 | for rule in rule_idx:
36 | full_rule_url = basic_url_val + rule
37 | rule_content = scrape_rule(full_rule_url)
38 | if rule_content == "":
39 | continue
40 | docs.append({
41 | "url" : full_rule_url,
42 | "text" : rule_content,
43 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
44 | "created_timestamp" : ""
45 | })
46 |
47 | # train / test split from process_tax_corpus_jhu.py
48 | random.seed(0)
49 | rand_idx = list(range(len(docs)))
50 | random.shuffle(rand_idx)
51 |
52 | train_idx = rand_idx[:int(len(rand_idx)*0.75)]
53 | val_idx = rand_idx[int(len(rand_idx)*0.75):]
54 |
55 | train_docs = np.array(docs)[train_idx]
56 | val_docs = np.array(docs)[val_idx]
57 |
58 | # saving data
59 | save_final_data(train_docs, "train.frcp.jsonl.xz")
60 | save_final_data(val_docs, "validation.frcp.jsonl.xz")
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/dataset_creation/federal_rules/fre.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from bs4 import BeautifulSoup
3 | import datetime
4 | import random
5 | import numpy as np
6 | import lzma as xz
7 | import json
8 |
9 | def scrape_rule(url_val):
10 | page = requests.get(url_val)
11 | soup = BeautifulSoup(page.content, "html.parser")
12 | rule_content = soup.find(id="content1")
13 | if rule_content is None:
14 | return ""
15 | remove_text = rule_content.find(id="book-navigation-4940270")
16 | add_title = soup.find(id="page-title")
17 | # removing the bottom navigation titles that get scraped with the rule body
18 | stop_idx = rule_content.text.find(remove_text.text)
19 | # appending the name and edited rule content
20 | final_rule = add_title.text + rule_content.text[:stop_idx].strip()
21 | return final_rule
22 |
23 | # saving the train / val files
24 | def save_final_data(state_json_data, final_path):
25 | with xz.open(final_path, 'w') as state_data:
26 | for state_doc in state_json_data:
27 | state_data.write((json.dumps(state_doc) + "\n").encode("utf-8"))
28 |
29 | basic_url_val = "https://www.law.cornell.edu/rules/fre/"
30 |
31 | # get the number of rules in an article
32 | total_sub_rules_per_rule = []
33 | page = requests.get(basic_url_val)
34 | soup = BeautifulSoup(page.content, "html.parser")
35 | rule_table_of_contents = soup.find(id="content1")
36 | sub_rules = rule_table_of_contents.find_all("ol", class_="bullet")
37 | for sub in sub_rules:
38 | total_rules = len(sub.find_all("li"))
39 | if total_rules <= 15:
40 | total_sub_rules_per_rule.append(total_rules)
41 |
42 | docs = []
43 | for i in range(1, 12):
44 | specific_r_count = total_sub_rules_per_rule[i - 1]
45 | for j in range(1, specific_r_count + 1):
46 | rule_num = i*100 + j
47 | rule_url = basic_url_val + "rule_" + str(rule_num)
48 | rule_content = scrape_rule(rule_url)
49 | if rule_content == "":
50 | continue
51 | docs.append({
52 | "url" : rule_url, #TODO how to pull google drive links
53 | "text" : rule_content,
54 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
55 | "created_timestamp" : ""
56 | })
57 |
58 | # train / test split from process_tax_corpus_jhu.py
59 | random.seed(0)
60 | rand_idx = list(range(len(docs)))
61 | random.shuffle(rand_idx)
62 |
63 | train_idx = rand_idx[:int(len(rand_idx)*0.75)]
64 | val_idx = rand_idx[int(len(rand_idx)*0.75):]
65 |
66 | train_docs = np.array(docs)[train_idx]
67 | val_docs = np.array(docs)[val_idx]
68 |
69 | # saving data
70 | save_final_data(train_docs, "train.fre.jsonl.xz")
71 | save_final_data(val_docs, "validation.fre.jsonl.xz")
72 |
73 |
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/dataset_creation/founder_docs/scrape_founding_docs.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | url= "https://founders.archives.gov/Metadata/founders-online-metadata.json"
4 | url_template = "https://www.constituteproject.org/constitution/{const_id}?lang=en"
5 |
6 | from bs4 import BeautifulSoup
7 | from bs4 import BeautifulSoup
8 | import requests_cache
9 | import os
10 | import time
11 | import os
12 | from tqdm import tqdm
13 | import numpy as np
14 | import json
15 | import random
16 | import datetime
17 | requests = requests_cache.CachedSession('scotus')
18 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'}
19 | try:
20 | import lzma as xz
21 | except ImportError:
22 | import pylzma as xz
23 |
24 | cache = "./cache"
25 |
26 | if not os.path.exists(cache):
27 | os.mkdir(cache)
28 |
29 |
30 | response = requests.get(url, headers=headers)
31 | ids = [r["permalink"] for r in response.json()]
32 | overwrite = True
33 | open_type = 'w' if overwrite else 'a'
34 | train_f = xz.open("./cache/train.founding_docs.xz", open_type)
35 | val_f = xz.open("./cache/validation.founding_docs.xz", open_type)
36 |
37 | for url in tqdm(ids):
38 | # url = url_template.format(const_id= _id)
39 | tag = url.split("documents/")[-1]
40 | url = f"https://founders.archives.gov/API/docdata/{tag}"
41 | response = requests.get(url, headers=headers)
42 | if not response.from_cache:
43 | time.sleep(.1)
44 |
45 | try:
46 | _dict = response.json()
47 | except:
48 | print("Problem loading response")
49 | print(response.content)
50 | continue
51 | text = f"Title: {_dict['title']}\nFrom: {','.join(_dict['authors'])}\nTo: {','.join(_dict['recipients'])}\n\n{_dict['content']}"
52 | if _dict['date-from'] is not None:
53 | created_date = "-".join(_dict['date-from'].split("-")[1:]) + "-" + _dict['date-from'].split("-")[0]
54 | else:
55 | created_date = ""
56 |
57 | datapoint = {
58 | "text" : text,
59 | "created_timestamp" : created_date,
60 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
61 | "url" : url
62 | }
63 |
64 | # This dataset is already heavily US biased, so it would be really weird not to train on the US
65 | # constitution
66 | if random.random() > .75:
67 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
68 | else:
69 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
70 |
71 |
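# A small worked example (hypothetical date value) of the date reshuffling above:
# the API's ISO-style "YYYY-MM-DD" date-from is rewritten as "MM-DD-YYYY".
date_from = "1776-07-04"
created = "-".join(date_from.split("-")[1:]) + "-" + date_from.split("-")[0]
print(created)  # -> 07-04-1776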
--------------------------------------------------------------------------------
/dataset_creation/ftcadvisories/scrape_ftc_advisories.py:
--------------------------------------------------------------------------------
1 | from bs4 import BeautifulSoup
2 | from bs4 import BeautifulSoup
3 | import requests_cache
4 | import os
5 | import time
6 | import os
7 | import textract
8 | import numpy as np
9 | import json
10 | import random
11 | import datetime
12 | requests = requests_cache.CachedSession('scotus')
13 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'}
14 | try:
15 | import lzma as xz
16 | except ImportError:
17 | import pylzma as xz
18 |
19 | cache = "./cache"
20 |
21 | if not os.path.exists(cache):
22 | os.mkdir(cache)
23 | url_template = "https://www.ftc.gov/policy/advisory-opinions?page={page}"
24 |
25 |
26 | overwrite = True
27 | open_type = 'w' if overwrite else 'a'
28 | train_f = xz.open("./cache/train.ftc_advisory_opinions.xz", open_type)
29 | val_f = xz.open("./cache/validation.ftc_advisory_opinions.xz", open_type)
30 |
31 | for page in range(0, 20):
32 | print(page)
33 |
34 | url = url_template.format(page = page)
35 | response = requests.get(url, headers=headers)
36 | soup= BeautifulSoup(response.text, "html.parser")
37 | potential_links = soup.select("a[href$='.pdf']")
38 | # import pdb; pdb.set_trace()
39 | potential_links = [link for link in potential_links if "advisory_opinions" in link["href"]]
40 | # potential_dates = [d.get_text() for d in soup.find_all("span", {"class" : "date-display-single"})]
41 |
42 | # assert len(potential_dates) == len(potential_links)
43 | if len(potential_links) == 0:
44 | print(f"SKipping for year {page} because no pdf links")
45 | continue
46 |
47 | for link in potential_links:
48 | try:
49 | response = requests.get(link["href"], headers=headers)
50 | except:
51 | print(f"PROBLEM GETTING {link}")
52 | time.sleep(random.random()*5)
53 | continue
54 |
55 | if not response.from_cache:
56 | time.sleep(random.random()*2.)
57 |
58 |
59 | with open(f'{cache}/{link["href"].split("/")[-1]}', 'wb') as f:
60 | f.write(response.content)
61 |
62 | try:
63 | text = str(textract.process(f'{cache}/{link["href"].split("/")[-1]}'))
64 | except:
65 | print(f"Problem with {link['href']}")
66 | continue
67 | if len(text) < len('\\x0c\\x0c') * 2:
68 | try:
69 | text = str(textract.process(f'{cache}/{link["href"].split("/")[-1]}', method='tesseract'))
70 | except:
71 | print(f"Problem with {link['href']}")
72 | continue
73 | if len(text) < len('\\x0c\\x0c') * 2:
74 | continue
75 |
76 | os.remove(f'{cache}/{link["href"].split("/")[-1]}')
77 | datapoint = {
78 | "text" : text,
79 | "created_timestamp" : "",
80 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
81 | "url" : link["href"]
82 | }
83 | # import pdb; pdb.set_trace()
84 | if random.random() > .75:
85 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
86 | else:
87 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
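# A minimal sketch (hypothetical helper, not part of the original script) of the
# extraction-with-OCR fallback used above: if plain textract output is essentially
# empty (just form feeds), retry with tesseract before giving up.
def extract_with_ocr_fallback(pdf_path):
    text = str(textract.process(pdf_path))
    if len(text) < len('\\x0c\\x0c') * 2:  # likely an image-only PDF with no text layer
        text = str(textract.process(pdf_path, method='tesseract'))
    if len(text) < len('\\x0c\\x0c') * 2:
        return None  # still nothing useful
    return text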
--------------------------------------------------------------------------------
/dataset_creation/oig_reports/scrape_oig_reports.py:
--------------------------------------------------------------------------------
1 | url = "https://archive.org/download/us-inspectors-general.bulk/us-inspectors-general.bulk.zip"
2 |
3 | import zipfile
4 | import os
5 | import datetime
6 | import random
7 | try:
8 | import lzma as xz
9 | except ImportError:
10 | import pylzma as xz
11 |
12 | cache_path = "./cache/us-inspectors-general.bulk.zip"
13 | if not os.path.exists(cache_path):
14 | import wget
15 | filename = wget.download(url, out="./cache/")
16 |
17 | import os
18 | import zipfile
19 | overwrite = True
20 | open_type = 'w' if overwrite else 'a'
21 | train_f = xz.open("./cache/train.oig.xz", open_type)
22 | val_f = xz.open("./cache/validation.oig.xz", open_type)
23 | import os, json
24 |
25 | i = 0
26 | with zipfile.ZipFile(cache_path) as z:
27 | for filename in z.namelist():
28 | if not os.path.isdir(filename) and ".txt" in filename:
29 | # read the file
30 | with z.open(filename) as f:
31 | year = filename.split("/")[1]
32 | text = str(f.read())
33 |
34 | datapoint = {
35 | "url" : url,
36 | "text" : text,
37 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
38 | "created_timestamp" : year
39 | }
40 |
41 | if random.random() > .75:
42 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
43 | else:
44 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
45 | i += 1
46 |
47 | if i % 5000 == 0: print(i)
--------------------------------------------------------------------------------
/dataset_creation/olc/scrape_olc_memos.py:
--------------------------------------------------------------------------------
1 | from bs4 import BeautifulSoup
2 | import requests
3 | import os
4 | import time
5 | import os
6 | import textract
7 | import numpy as np
8 | import json
9 | import random
10 | import datetime
11 |
12 | try:
13 | import lzma as xz
14 | except ImportError:
15 | import pylzma as xz
16 |
17 | cache = "./cache"
18 |
19 | if not os.path.exists(cache):
20 | os.mkdir(cache)
21 | base_url = "https://www.justice.gov/olc/opinions"
22 |
23 | pages = 138
24 | docs = []
25 |
26 | def _get_opinions(soup):
27 | g_data = soup.findAll("td", {"class": "views-field-field-opinion-attachment-file"})
28 | g_dates = soup.findAll("span", {"class" : "date-display-single"})
29 | for data, date in zip(g_data, g_dates):
30 | link = "https://www.justice.gov/" + data.find("a")["href"]
31 | tag = link.split("/")[-2]
32 | if not os.path.exists(f'{cache}/{tag}.pdf'):
33 | print(link)
34 | pdf = requests.get(link)
35 |
36 | with open(f'{cache}/{tag}.pdf', 'wb') as f:
37 | f.write(pdf.content)
38 |
39 | text = str(textract.process(f'{cache}/{tag}.pdf'))
40 |
41 | issuance_date = date.text
42 | print(issuance_date)
43 | issuance_date = issuance_date.replace("/", "-")
44 | docs.append({
45 | "text" : text,
46 | "created_timestamp" : issuance_date,
47 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
48 | "url" : link
49 | })
50 |
51 |
52 | pages = 138
53 | page2 = requests.get(base_url)
54 |
55 | soup = BeautifulSoup(page2.text, "lxml")
56 | button_next = soup.find("a", {"title": "Go to next page"}, href=True)
57 |
58 | _get_opinions(soup)
59 |
60 | while button_next:
61 | time.sleep(2)  # delay between requests so we don't get kicked by the server
62 | url2 = "https://www.justice.gov{0}".format(button_next["href"])
63 | page2=requests.get(url2)
64 | soup=BeautifulSoup(page2.text,"lxml")
65 | _get_opinions(soup)
66 | button_next = soup.find("a", {"title": "Go to next page"}, href=True)
67 |
68 | def save_to_processed(train, val, source_name, out_path):
69 | if not os.path.exists(out_path):
70 | os.makedirs(out_path)
71 | tf = os.path.join(out_path, f"train.{source_name}.jsonl")
72 | with open(tf, mode='w', encoding='utf-8') as out_file:
73 | for line in train:
74 | out_file.write(json.dumps(line) + "\n")
75 | print(f"Written {len(train)} documents to {tf}")
76 |
77 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl")
78 | with open(vf, mode='w', encoding='utf-8') as out_file:
79 | for line in val:
80 | out_file.write(json.dumps(line) + "\n")
81 | print(f"Written {len(val)} documents to {vf}")
82 |
83 | # now compress with lib
84 | print("compressing files...")
85 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out:
86 | out.write(xz.compress(bytes(f.read())))
87 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out:
88 | out.write(xz.compress(bytes(f.read())))
89 | print("compressed")
90 |
91 | random.seed(0) # important for shuffling
92 | rand_idx = list(range(len(docs)))
93 | random.shuffle(rand_idx)
94 |
95 | train_idx = rand_idx[:int(len(rand_idx)*0.75)]
96 | val_idx = rand_idx[int(len(rand_idx)*0.75):]
97 |
98 | train_docs = np.array(docs)[train_idx]
99 | val_docs = np.array(docs)[val_idx]
100 |
101 | save_to_processed(train_docs, val_docs, "olcmemos", "./cache/")
102 |
103 |
--------------------------------------------------------------------------------
/dataset_creation/process_cfr.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Processes CFR XML files.
3 | # Raw XML can be downloaded from: https://www.govinfo.gov/bulkdata/CFR/2020/CFR-2020.zip
4 |
5 | import xml.etree.ElementTree as ET
6 | import glob
7 | import os
8 | import json
9 | import random
10 | from tqdm import tqdm
11 | random.seed(0) # important for shuffling
12 |
13 | # FILE PATHS
14 | RAW_PATH = "../data/cfr/raw" # path to title folders
15 | OUT_PATH = "../data/cfr/processed" # path to write data out to
16 |
17 |
18 | files = glob.glob(os.path.join(RAW_PATH, "title-*/*.xml"))
19 | print(f"{len(files)} total documents.")
20 | outputs = []
21 | for file in tqdm(files):
22 | doc_name = file.split("/")[-1]
23 | tree = ET.parse(file)
24 | all_text = ET.tostring(tree.getroot(), encoding='utf-8', method='text')
25 | outputs.append({
26 | "text": all_text.decode("utf-8"),
27 | "url": "https://www.govinfo.gov/bulkdata/CFR/2020/CFR-2020.zip",
28 | "timestamp": "08-17-2021"
29 | })
30 |
31 | random.shuffle(outputs)
32 | train = outputs[:int(len(files)*0.75)]
33 | val = outputs[int(len(files)*0.75):]
34 |
35 | # save to processed
36 | tf = os.path.join(OUT_PATH, "train.cfr.jsonl")
37 | with open(tf, "w") as out_file:
38 | for o in train:
39 | out_file.write(json.dumps(o) + "\n")
40 | print(f"Written {len(train)} documents to {tf}")
41 |
42 | vf = os.path.join(OUT_PATH, "validation.cfr.jsonl")
43 | with open(vf, "w") as out_file:
44 | for o in val:
45 | out_file.write(json.dumps(o) + "\n")
46 | print(f"Written {len(val)} documents to {vf}")
--------------------------------------------------------------------------------
/dataset_creation/process_scotus_oral_arguments.py:
--------------------------------------------------------------------------------
1 | # Script for extracting transcript data from SCOTUS Oral Arguments.
2 | # Source: https://github.com/walkerdb/supreme_court_transcripts/releases/tag/2021-08-14
3 | # We prepend the speaker to each line of text. Where the speaker is unknown, we use "Speaker"
4 |
5 | import glob
6 | import json
7 | import os
8 | import random
9 | from tqdm import tqdm
10 | from dateutil import parser
11 | import re
12 | import lzma
13 | import datetime
14 |
15 |
16 | URL = "https://github.com/walkerdb/supreme_court_transcripts/releases/tag/2021-08-14"
17 | IN_DIR = "../../data/scotus_oral/supreme_court_transcripts-2021-08-14/oyez/cases/"
18 | OUT_DIR = "../../data/scotus_oral/processed"
19 |
20 | def process_transcript(d):
21 | doc = []
22 | try:
23 | for el in d["transcript"]['sections']:
24 | for turn in el['turns']:
25 | speaker = "Speaker"
26 | if turn['speaker'] is not None:
27 | speaker = turn['speaker']['name']
28 | text = [t['text'] for t in turn['text_blocks']]
29 | text = " ".join(text)
30 | text = f"{speaker}: {text}"
31 | doc.append(text)
32 | except:
33 | print(d)
34 | return "\n".join(doc)
35 |
36 | def save_to_file(data, fpath):
37 | with open(fpath, "w") as out_file:
38 | for x in data:
39 | out_file.write(json.dumps(x) + "\n")
40 | print(f"Written {len(data)} to {fpath}")
41 |
42 | def main():
43 |
44 | in_files = glob.glob(os.path.join(IN_DIR, "*t*.json"))
45 |
46 | docs = []
47 | for f in tqdm(in_files):
48 | data = json.load(open(f))
49 | if data['unavailable'] or data['transcript'] is None:
50 | continue
51 | text = process_transcript(data)
52 |
53 | # Get oral argument date
54 | title = data['title']
55 | match = re.search(r'(Jan(uary)?|Feb(ruary)?|Mar(ch)?|Apr(il)?|May|Jun(e)?|Jul(y)?|Aug(ust)?|Sep(tember)?|Oct(ober)?|Nov(ember)?|Dec(ember)?)\s+\d{1,2},\s+\d{4}', title)
56 | oral_date = parser.parse(match.group()).strftime("%m-%d-%Y")
57 |
58 | doc = {
59 | "url": URL,
60 | "created_timestamp": oral_date,
61 | "timestamp": datetime.date.today().strftime("%m-%d-%Y"),
62 | "text": text
63 | }
64 | docs.append(doc)
65 |
66 | # shuffle and split into train / validation
67 | random.seed(0)
68 | random.shuffle(docs)
69 | train = docs[:int(len(docs)*0.75)]
70 | validation = docs[int(len(docs)*0.75):]
71 |
72 | save_to_file(train, os.path.join(OUT_DIR, "train.scotus_oral.jsonl"))
73 | save_to_file(validation, os.path.join(OUT_DIR, "validation.scotus_oral.jsonl"))
74 |
75 | if __name__ == "__main__":
76 | main()
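# Illustrative only (synthetic transcript fragment, not real Oyez data):
# process_transcript flattens sections/turns into "Speaker: text" lines.
sample = {
    "transcript": {
        "sections": [
            {"turns": [
                {"speaker": {"name": "John G. Roberts, Jr."},
                 "text_blocks": [{"text": "We will hear argument first this morning."}]},
                {"speaker": None,
                 "text_blocks": [{"text": "Mr. Chief Justice, and may it please the Court."}]},
            ]}
        ]
    }
}
print(process_transcript(sample))
# -> John G. Roberts, Jr.: We will hear argument first this morning.
#    Speaker: Mr. Chief Justice, and may it please the Court.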
--------------------------------------------------------------------------------
/dataset_creation/r_legaladvice/scrape_r_legaladvice.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import time
3 | import random
4 | import datetime
5 | import json
6 | try:
7 | import lzma as xz
8 | except ImportError:
9 | import pylzma as xz
10 | subreddits = ['legaladvice', 'legaladviceofftopic']
11 | maxThings = -1
12 | printWait = 2
13 | requestSize = 100
14 | from profanity_check import predict, predict_prob
15 |
16 |
17 | def requestJSON(url):
18 | while True:
19 | try:
20 | r = requests.get(url)
21 | if r.status_code != 200:
22 | print('error code', r.status_code)
23 | time.sleep(5)
24 | continue
25 | else:
26 | break
27 | except Exception as e:
28 | print(e)
29 | time.sleep(5)
30 | continue
31 | return r.json()
32 | overwrite = False
33 | open_type = 'w' if overwrite else 'a'
34 | train_f = xz.open("./cache/train.r_legaladvice.xz", open_type)
35 | val_f = xz.open("./cache/validation.r_legaladvice.xz", open_type)
36 |
37 | for subreddit in subreddits:
38 |
39 | meta = requestJSON('https://api.pushshift.io/meta')
40 | limitPerMinute = meta['server_ratelimit_per_minute']
41 | requestWait = 60 / limitPerMinute
42 |
43 | print('server_ratelimit_per_minute', limitPerMinute)
44 |
45 |
46 |
47 | things = ('submission',)
48 | def get_comments_for_post(post_id):
49 | url = 'https://api.pushshift.io/reddit/comment/search?link_id='\
50 | + post_id \
51 | + f'&metadata=true&limit=20000&subreddit={subreddit}'
52 | jsond = requestJSON(url)
53 | time.sleep(requestWait)
54 | return jsond['data']
55 |
56 | for thing in things:
57 | i = 0
58 |
59 | print('\n[starting', thing + 's]')
60 |
61 | if maxThings < 0:
62 |
63 | url = 'https://api.pushshift.io/reddit/search/'\
64 | + thing + '/?subreddit='\
65 | + subreddit\
66 | + '&metadata=true&size=0'
67 |
68 | jsond = requestJSON(url)
69 |
70 | totalResults = jsond['metadata']['total_results']
71 | print('total ' + thing + 's', 'in', subreddit,':', totalResults)
72 | else:
73 | totalResults = maxThings
74 | print('downloading most recent', maxThings)
75 |
76 |
77 | created_utc = '1612822911' if subreddit == 'legaladvice' else ''
78 |
79 | startTime = time.time()
80 | timePrint = startTime
81 | while True:
82 | url = 'http://api.pushshift.io/reddit/search/'\
83 | + thing + '/?subreddit=' + subreddit\
84 | + '&size=' + str(requestSize)\
85 | + '&before=' + str(created_utc)
86 |
87 | jsond = requestJSON(url)
88 |
89 | if len(jsond['data']) == 0:
90 | break
91 |
92 | doneHere = False
93 | for post in jsond['data']:
94 | created_utc = post["created_utc"]
95 | # f.write(str(post) + '\n')
96 | i += 1
97 | comments = get_comments_for_post(post['id'])
98 |
99 | filtered_comments = [comment for comment in comments if comment['score'] >= 8 \
100 | and predict_prob([comment['body']])[0] < .8 \
101 | and '[removed]' \
102 | not in comment['body'] \
103 | and '[deleted]' \
104 | not in comment['body'] \
105 | and post['id'] in comment['parent_id']
106 | ]
107 |
108 | if len(filtered_comments) == 0:
109 | continue
110 |
111 | if 'selftext' in post:
112 | selector = 'selftext'
113 | elif 'text' in post:
114 | selector = 'text'
115 | elif 'body' in post:
116 | selector = 'body'
117 | else:
118 | print("Couldn't find a selector for post text, continuing")
119 | print(post)
120 | continue
121 | post_text = post[selector]
122 | text = f"Title: {post['title']}\nQuestion:{post_text}\n"
123 | if 'link_flair_text' in post and post['link_flair_text']:
124 | text += f"Topic:\n{post['link_flair_text']}\n"
125 | elif 'link_flair_richtext' in post and post['link_flair_richtext']:
126 | text += f"Topic:\n{post['link_flair_text']}\n"
127 |
128 | for q, comment in enumerate(sorted(filtered_comments, key=lambda x: x['score'], reverse=True)):
129 | text += f'Answer #{q+1}: {comment["body"]}'
130 |
131 | datapoint = {
132 | "url" : post['url'],
133 | "text" : text,
134 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
135 | "created_timestamp" : datetime.datetime.fromtimestamp(created_utc).strftime("%m-%d-%Y")
136 | }
137 | if random.random() > .75:
138 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
139 | else:
140 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
141 |
142 | if i >= totalResults:
143 | doneHere = True
144 | break
145 |
146 | if doneHere:
147 | break
148 |
149 | if time.time() - timePrint > printWait:
150 | timePrint = time.time()
151 | percent = i / totalResults * 100
152 |
153 | timePassed = time.time() - startTime
154 |
155 | print('{:.2f}'.format(percent) + '%', '|',
156 | time.strftime("%H:%M:%S", time.gmtime(timePassed)))
157 |
158 |
159 | time.sleep(requestWait)
160 |
161 | train_f.close()
162 | val_f.close()
163 |
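# A small sketch (synthetic Pushshift-style record, hypothetical helper) of the
# comment filter applied above: keep top-level, non-deleted, reasonably upvoted,
# low-profanity replies to the post.
def keep_comment(post_id, comment):
    return (comment['score'] >= 8
            and predict_prob([comment['body']])[0] < .8
            and '[removed]' not in comment['body']
            and '[deleted]' not in comment['body']
            and post_id in comment['parent_id'])  # top-level reply: parent is the post itself

example = {"score": 12, "body": "You should contact a local attorney.", "parent_id": "t3_abc123"}
print(keep_comment("abc123", example))  # -> True (assuming the profanity model scores the body low)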
--------------------------------------------------------------------------------
/dataset_creation/scotus_dockets/scrape_scotus_dockets.py:
--------------------------------------------------------------------------------
1 | from bs4 import BeautifulSoup
2 | from bs4 import BeautifulSoup
3 | import requests_cache
4 | import os
5 | import time
6 | import os
7 | import textract
8 | import numpy as np
9 | import json
10 | import random
11 | import datetime
12 | requests = requests_cache.CachedSession('scotus')
13 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'}
14 | try:
15 | import lzma as xz
16 | except ImportError:
17 | import pylzma as xz
18 |
19 | cache = "./cache"
20 |
21 | if not os.path.exists(cache):
22 | os.mkdir(cache)
23 | url_template_default = "https://www.supremecourt.gov/search.aspx?filename=/docket/docketfiles/html/public/{year}-{n}.html"
24 | aurl_template = "https://www.supremecourt.gov/search.aspx?filename=/docket/docketfiles/html/public/{year}A{n}.html"
25 |
26 |
27 | overwrite = True
28 | open_type = 'w' if overwrite else 'a'
29 | train_f = xz.open("./cache/train.scotus_docket_entries.xz", open_type)
30 | val_f = xz.open("./cache/validation.scotus_docket_entries.xz", open_type)
31 |
32 | for year in list(range(18, 22))[::-1]:
33 |
34 | start_ns = [0, 5001, "A"]
35 | # print(year)
36 | # if year == 21:
37 | # vals = list(range(1,558)) + list(range(5001))
38 | # elif year == 20:
39 | # vals = range(1, 8478)
40 | # elif year == 19:
41 | # vals = range(1, 8930)
42 | # elif year == 18:
43 | # vals = range(1, 9849)
44 | # else:
45 | # raise ValueError("blerg")
46 |
47 | for start_n in start_ns:
48 | if start_n == "A":
49 | start_n = 0
50 | url_template = aurl_template
51 | else:
52 | url_template = url_template_default
53 |
54 | no_link_count = 0
55 | for n in range(start_n, start_n+5000):
56 | url = url_template.format(year=year, n=n)
57 | response = requests.get(url, headers=headers)
58 | soup= BeautifulSoup(response.text, "html.parser")
59 | potential_links = soup.select("a[href$='.pdf']")
60 | # import pdb; pdb.set_trace()
61 | potential_links = [link for link in potential_links if "DocketPDF" in link["href"]]
62 |
63 | if len(potential_links) == 0:
64 | print(f"SKipping for year {year} at n {n} because no pdf links")
65 | no_link_count += 1
66 | if no_link_count <= 10:
67 | continue
68 | else:
69 | break
70 |
71 | no_link_count = 0
72 |
73 | for link in potential_links:
74 | try:
75 | response = requests.get(link["href"], headers=headers)
76 | except:
77 | print(f"PROBLEM GETTING {link}")
78 | time.sleep(random.random()*5)
79 | continue
80 |
81 | if not response.from_cache:
82 | time.sleep(random.random()*2.)
83 |
84 | filepath = f'{cache}/{link["href"].split("/")[-1]}'
85 | if len(filepath) > 200:
86 | import uuid
87 | filepath = f'{cache}/{str(uuid.uuid4())}'
88 |
89 | with open(filepath, 'wb') as f:
90 | f.write(response.content)
91 |
92 | try:
93 | text = str(textract.process(filepath))
94 | except:
95 | print(f"Problem with {link['href']}")
96 | continue
97 | if len(text) < len('\\x0c\\x0c') * 2:
98 | try:
99 | text = str(textract.process(filepath, method='tesseract'))
100 | except:
101 | print(f"Problem with {link['href']}")
102 | continue
103 | if len(text) < len('\\x0c\\x0c') * 2:
104 | continue
105 |
106 | os.remove(filepath)
107 | datapoint = {
108 | "text" : text,
109 | "created_timestamp" : f'{link["href"].split("/")[-1][4:6]}-{link["href"].split("/")[-1][6:8]}-{link["href"].split("/")[-1][0:4]}',
110 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
111 | "url" : link["href"]
112 | }
113 | if random.random() > .75:
114 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
115 | else:
116 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
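# A worked example (hypothetical filename) of the created_timestamp slicing above,
# which assumes the final path segment starts with a YYYYMMDD... upload timestamp.
fname = "20211108123456789_Petition.pdf"
print(f"{fname[4:6]}-{fname[6:8]}-{fname[0:4]}")  # -> 11-08-2021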
--------------------------------------------------------------------------------
/dataset_creation/scrape_nlrb_decisions.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import requests
3 | import re
4 | from tqdm import tqdm
5 | import wget
6 | import os
7 | import json
8 | import PyPDF2
9 | import glob
10 | import datetime
11 | import random
12 |
13 | from multiprocessing import Pool
14 |
15 |
16 | DIR = "../data/nlrb/scraping"
17 | OUT_DIR = "../data/nlrb/processed"
18 | BASE_URL = "https://www.nlrb.gov/cases-decisions/decisions/board-decisions?search_term=&op=Search&volume=-1&slip_opinion_number=&page_number=$PAGE_NUMBER&items_per_page=100&form_build_id=form-EXFTGMwfIM0yO2L6ENa30_RuqwbWvmk5YR1UYcvgqsA&form_id=board_decisions_form"
19 |
20 |
21 | def get_urls(args):
22 | all_matches = []
23 | for page_num in tqdm(range(661)):
24 | url = BASE_URL.replace("$PAGE_NUMBER", str(page_num))
25 | page = requests.get(url)
26 | text = page.text
27 | #p = re.compile("href=\"(https\:\/\/apps\.nlrb\.gov\/link\/document\.aspx\/[a-z0-9]*)\"")
28 | p = re.compile("([0-9]*\/[0-9]*\/[0-9]*) <\/td>\s*[0-9]* NLRB No\. [0-9]*<\/td>\s*
--------------------------------------------------------------------------------
/dataset_creation/state_codes/process_existing_state_code_files.py:
--------------------------------------------------------------------------------
106 | # Some states have a Chapters -> Sections -> section format
107 | # Tennessee has Chapters -> Chapter -> Sections -> section
108 | # some others have just Sections -> section (no Chapters)
109 | # Alaska uses secs instead of sections
110 | # TODO: Montana has parts and part
111 |
112 | # some have sects instead of sections
113 | all_text = _unroll(file_content_dict)
114 |
115 |
116 | # if (isinstance(file_content_dict, list)) and (isinstance(file_content_dict[0], dict)):
117 | # top_keys = file_content_dict[0].keys()
118 | # if "link" in top_keys:
119 | # needed_link = file_content_dict[0]["link"] # for output
120 | # if "chapters" in top_keys:
121 | # # go to one level down look for sections
122 | # # multiple sections - need to parse here
123 | # for sect_elem in file_content_dict[0]["chapters"]:
124 | # sub_keys = sect_elem.keys()
125 | # if not sect_elem:
126 | # print (f" *** Empty dict hit -- {ind_file}")
127 | # continue
128 | # if "raws" in sub_keys:
129 | # to_parse_sections.append(sect_elem)
130 | # continue
131 |
132 | # if "sections" in sub_keys:
133 | # to_parse_sections = sect_elem["sections"]
134 | # elif "secs" in sub_keys:
135 | # to_parse_sections = sect_elem["secs"]
136 | # else:
137 | # print (f"*** subkeys {sub_keys} ")
138 | # print (f"*** ERROR - No Sections under Chapters in {ind_file} ")
139 | # elif "sections" in top_keys:
140 | # to_parse_sections = file_content_dict[0]["sections"]
141 |
142 | # to_parse_sections has the list of all sections, combine texts from all section elements
143 |
144 | # some section (especially the last one) is empty - check for these
145 | # for ind_sect in to_parse_sections:
146 | # if ind_sect:
147 | # if "raws" in ind_sect.keys():
148 | # all_text = all_text + " "+ " ".join(ind_sect['raws'])
149 |
150 | # create dictionary for output
151 | out_dict = {}
152 | out_dict ["url"] = url # should the key be url or link
153 | out_dict ["text"] = all_text
154 | #out_dict ["created_timestamp"] = TIME_STAMP
155 | out_dict ["created_timestamp"] = ind_file.split("_")[-1].replace(".json", "")
156 | out_dict ["downloaded_timestamp"] = datetime.date.today().strftime("%m-%d-%Y")
157 | out_dict ["state_year"] = state_name_year
158 |
159 | if random.random() > .75:
160 | val_f.write((json.dumps(out_dict) + "\n").encode("utf-8"))
161 | else:
162 | train_f.write((json.dumps(out_dict) + "\n").encode("utf-8"))
163 |
--------------------------------------------------------------------------------
/dataset_creation/state_codes/state_codes_from_scratch.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from bs4 import BeautifulSoup
3 | import datetime
4 | import random
5 | import numpy as np
6 | import lzma as xz
7 | import json
8 |
9 | # The other file was run to grab the state codes. This captures
10 | # all the states with no errors, but takes way too long to run (it times out). If
11 | # someone has capacity to fix the bandwidth issues on this, feel free to take a stab!
12 |
13 | def state_code_scrape(url_val):
14 | page = requests.get(url_val)
15 | soup = BeautifulSoup(page.content, "html.parser")
16 | results = soup.find(id="codes-content")
17 | return results.text
18 |
19 | def year_list_scrape(url_val):
20 | page = requests.get(url_val)
21 | soup = BeautifulSoup(page.content, "html.parser")
22 | results = soup.find(id="main-content")
23 | code_years = results.find("ul")
24 | year_list = [year.text[:4] for year in code_years]
25 | return year_list
26 |
27 | # saving the train / val files
28 | def save_final_data(state_json_data, final_path):
29 | with xz.open(final_path, 'w') as state_data:
30 | for state_doc in state_json_data:
31 | state_data.write((json.dumps(state_doc) + "\n").encode("utf-8"))
32 |
33 | def create_nested_doc(url_val, base_url):
34 | final_text = ""
35 | page = requests.get(url_val)
36 | soup = BeautifulSoup(page.content, "html.parser")
37 | title_sections = soup.find("div", class_="codes-listing")
38 | if title_sections is None:
39 | return state_code_scrape(url_val)
40 | else:
41 | title_section = title_sections.find_all("li")
42 | for sec in title_section:
43 | final_link = ""
44 | further_sec = sec.find_all('a', href=True)
45 | for l in further_sec:
46 | final_link = l['href']
47 | if final_link != "":
48 | sec_url = base_url + final_link
49 | final_text += create_nested_doc(sec_url, base_url)
50 | return final_text
51 |
52 | state_names = ["Alaska", "Alabama", "Arkansas", "Arizona", "California", "Colorado", "Connecticut", "District-of-columbia", "Delaware", "Florida", "Georgia", "Guam", "Hawaii", "Iowa", "Idaho", "Illinois", "Indiana", "Kansas", "Kentucky", "Louisiana", "Massachusetts", "Maryland", "Maine", "Michigan", "Minnesota", "Missouri", "Mississippi", "Montana", "North-carolina", "North-dakota", "Nebraska", "New-hampshire", "New-jersey", "New-mexico", "Nevada", "New-york", "Ohio", "Oklahoma", "Oregon", "Pennsylvania", "Puerto-rico", "Rhode-island", "South-carolina", "South-dakota", "Tennessee", "Texas", "Utah", "Virginia", "Virgin-Islands", "Vermont", "Washington", "Wisconsin", "West-virginia", "Wyoming"]
53 | state_names = [x.lower() for x in state_names]
54 |
55 | base_url = "https://law.justia.com"
56 | docs = []
57 | for state in state_names:
58 | state_base_url = base_url + '/codes/' + state
59 | year_list = year_list_scrape(state_base_url)
60 | for year in year_list:
61 | constructed_doc = ""
62 | year_url = state_base_url + '/' + year
63 | constructed_doc = create_nested_doc(year_url, base_url)
64 | docs.append({
65 | "url" : year_url, #TODO how to pull google drive links
66 | "text" : constructed_doc,
67 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
68 | "created_timestamp" : "",
69 | "state_year": state + '_' + year
70 | })
71 | print(constructed_doc)
72 |
73 | # train / test split from process_tax_corpus_jhu.py
74 | random.seed(0)
75 | rand_idx = list(range(len(docs)))
76 | random.shuffle(rand_idx)
77 |
78 | train_idx = rand_idx[:int(len(rand_idx)*0.75)]
79 | val_idx = rand_idx[int(len(rand_idx)*0.75):]
80 |
81 | train_docs = np.array(docs)[train_idx]
82 | val_docs = np.array(docs)[val_idx]
83 |
84 | # saving data
85 | save_final_data(train_docs, "train.state_codes.jsonl.xz")
86 | save_final_data(val_docs, "validation.state_codes.jsonl.xz")
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
/dataset_creation/tax_corpus_jhu/process_tax_corpus_jhu.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import json
4 | import datetime
5 | import numpy as np
6 | # Assumes you've downloaded
7 | link = "https://archive.data.jhu.edu/file.xhtml?persistentId=doi:10.7281/T1/N1X6I4/D5CQ0Y&version=2.0"
8 | # into cache/ and extracted it
9 | try:
10 | import lzma as xz
11 | except ImportError:
12 | import pylzma as xz
13 |
14 | with open('cache/plrs_tc_corpus_feb25.txt', 'r') as f:
15 | docs = [{
16 | "text" : " ".join(doc.split(" ")[1:]),
17 | "created_timestamp" : "",
18 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
19 | "url" : link
20 | } for doc in f.readlines()] # first word is always a file name
21 |
22 | # Note the preprocessing in this dataset is very messy. In the future we may wish to find the raw original proceedings.
23 |
24 | def save_to_processed(train, val, source_name, out_path):
25 | if not os.path.exists(out_path):
26 | os.makedirs(out_path)
27 | tf = os.path.join(out_path, f"train.{source_name}.jsonl")
28 | with open(tf, mode='w', encoding='utf-8') as out_file:
29 | for line in train:
30 | out_file.write(json.dumps(line) + "\n")
31 | print(f"Written {len(train)} documents to {tf}")
32 |
33 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl")
34 | with open(vf, mode='w', encoding='utf-8') as out_file:
35 | for line in val:
36 | out_file.write(json.dumps(line) + "\n")
37 | print(f"Written {len(val)} documents to {vf}")
38 |
39 | # now compress with lib
40 | print("compressing files...")
41 | with open(vf, 'rb') as f, open(vf+".xz", 'wb') as out:
42 | out.write(xz.compress(bytes(f.read())))
43 | with open(tf, 'rb') as f, open(tf+".xz", 'wb') as out:
44 | out.write(xz.compress(bytes(f.read())))
45 | print("compressed")
46 |
47 | random.seed(0) # important for shuffling
48 | rand_idx = list(range(len(docs)))
49 | random.shuffle(rand_idx)
50 |
51 | train_idx = rand_idx[:int(len(rand_idx)*0.75)]
52 | val_idx = rand_idx[int(len(rand_idx)*0.75):]
53 |
54 | train_docs = np.array(docs)[train_idx]
55 | val_docs = np.array(docs)[val_idx]
56 |
57 | save_to_processed(train_docs, val_docs, "taxrulings", "./cache/")
58 |
--------------------------------------------------------------------------------
/dataset_creation/tos/process_tos.py:
--------------------------------------------------------------------------------
1 | # Processes TOS XML files.
2 | # XML with annotation tags (clause type / whether clause is unfair) can be downloaded from: http://claudette.eui.eu/ToS.zip
3 | # XML files located in OriginalTaggedDocuments subdirectory
4 |
5 | import os
6 | import json
7 | import re
8 | import random
9 | import numpy as np
10 | import datetime
11 | from tqdm import tqdm
12 | from dateutil import parser
13 | from bs4 import BeautifulSoup
14 |
15 | # FILE PATHS
16 | RAW_PATH = "../../data/tos/raw" # path to contracts
17 | OUT_PATH = "../../data/tos/processed" # path to write data (docs of original text) out to
18 | OUT_PATH_TAGGED = "../../data/tos/processed_tagged" # path to write data (docs of original text with annotation tags) out to
19 |
20 | # Regex match for date
21 | match_text_patterns = ["Last Updated Date:", "Last Updated:", "Last updated:", "last updated on", "Last revised on", "Last Revised:", "Last revised on", "Revised", "Date of Last Revision:", "Last modified:", "Last modified by", "in effect as of", "Effective Date:", "Effective:", "effective on", "Effective on", "Applicable from:"]
22 | combined_text_regex = '(?:%s)' % '|'.join(match_text_patterns)
23 |
24 | match_date_patterns = [
25 | r'(Jan(uary)?|Feb(ruary)?|Mar(ch)?|Apr(il)?|May|Jun(e)?|Jul(y)?|Aug(ust)?|Sep(tember)?|Oct(ober)?|Nov(ember)?|Dec(ember)?)\s+\d{1,2}(th)?,\s+\d{4}',
26 | r'\d{1,2}\s+(Jan(uary)?|Feb(ruary)?|Mar(ch)?|Apr(il)?|May|Jun(e)?|Jul(y)?|Aug(ust)?|Sep(tember)?|Oct(ober)?|Nov(ember)?|Dec(ember)?)(,)?\s+\d{4}',
27 | r'\d{1,2}/\d{1,2}/\d{4}',
28 | r'\d{4}-\d{1,2}-\d{1,2}'
29 | ]
30 | combined_date_regex = '(?:%s)' % '|'.join(match_date_patterns)
31 |
32 | def save_to_processed(train, val, source_name, out_path):
33 | if not os.path.exists(out_path):
34 | os.makedirs(out_path)
35 | tf = os.path.join(out_path, f"train.{source_name}.jsonl")
36 | with open(tf, mode='w', encoding='utf-8') as out_file:
37 | for line in train:
38 | out_file.write(json.dumps(line) + "\n")
39 | print(f"Written {len(train)} documents to {tf}")
40 |
41 | vf = os.path.join(out_path, f"validation.{source_name}.jsonl")
42 | with open(vf, mode='w', encoding='utf-8') as out_file:
43 | for line in val:
44 | out_file.write(json.dumps(line) + "\n")
45 | print(f"Written {len(val)} documents to {vf}")
46 |
47 | def main():
48 | files = [os.path.join(RAW_PATH, f) for f in os.listdir(RAW_PATH)]
49 | print(f"{len(files)} total documents.")
50 | outputs = []
51 | outputs_tagged = []
52 | for file in tqdm(files):
53 | print("Processing:", file)
54 | with open(file, mode='r', encoding='utf-8') as in_file:
55 | text_tagged = in_file.read()
56 |
57 | # Remove annotation tags
58 | soup = BeautifulSoup(text_tagged, features="lxml")
59 | text = soup.get_text()
60 |
61 | # Extract creation date
62 | lines = text.splitlines()
63 | date_text = None
64 | for line in lines:
65 | date_text = re.search(combined_text_regex, line)
66 | # If matched text describing date, break at current line
67 | if date_text:
68 | break
69 | creation_date = ""
70 | if date_text:
71 | match = re.search(combined_date_regex, line)
72 | # If matched date, parse match
73 | if match:
74 | creation_date = parser.parse(match.group()).strftime("%m-%d-%Y")
75 |
76 | outputs.append({
77 | "url": "http://claudette.eui.eu/ToS.zip",
78 | "created_timestamp": creation_date,
79 | "downloaded_timestamp": datetime.date.today().strftime("%m-%d-%Y"),
80 | "text": text
81 | })
82 |
83 | outputs_tagged.append({
84 | "url": "http://claudette.eui.eu/ToS.zip",
85 | "created_timestamp": creation_date,
86 | "downloaded_timestamp": datetime.date.today().strftime("%m-%d-%Y"),
87 | "text": text_tagged
88 | })
89 |
90 | outputs = np.array(outputs)
91 | outputs_tagged = np.array(outputs_tagged)
92 |
93 | random.seed(0) # important for shuffling
94 | rand_idx = list(range(len(outputs)))
95 | random.shuffle(rand_idx)
96 |
97 | train_idx = rand_idx[:int(len(rand_idx)*0.75)]
98 | val_idx = rand_idx[int(len(rand_idx)*0.75):]
99 |
100 | # Same split for docs of original text and original text with annotation tags
101 | train = outputs[train_idx]
102 | val = outputs[val_idx]
103 |
104 | train_tagged = outputs_tagged[train_idx]
105 | val_tagged = outputs_tagged[val_idx]
106 |
107 | # Save train / val to processed
108 | save_to_processed(train, val, "tos", OUT_PATH)
109 |
110 | # Save train_tagged / val_tagged to processed_tagged
111 | save_to_processed(train_tagged, val_tagged, "tos_tagged", OUT_PATH_TAGGED)
112 |
113 | if __name__ == "__main__":
114 | main()
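# Illustrative only (made-up line): how the two regexes above pick a creation
# date out of a "Last updated" style line.
sample_line = "Last updated: January 5, 2018"
if re.search(combined_text_regex, sample_line):
    m = re.search(combined_date_regex, sample_line)
    if m:
        print(parser.parse(m.group()).strftime("%m-%d-%Y"))  # -> 01-05-2018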
--------------------------------------------------------------------------------
/dataset_creation/un_debates/scrape_un_debates.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | import os
4 | import textract
5 | import random
6 | import bs4
7 | import json
8 | import datetime
9 | try:
10 | import lzma as xz
11 | except ImportError:
12 | import pylzma as xz
13 | # TODO: can't fetch this automatically; have to download manually
14 | url = "https://dataverse.harvard.edu/file.xhtml?fileId=4590189&version=6.0#"
15 |
16 | cache_path = "./cache/UNGDC_1970-2020.tar.gz"
17 | if not os.path.exists(cache_path):
18 | raise ValueError(f"Need to download from {url}")
19 | # import wget
20 | # filename = wget.download(url, out="./cache/")
21 | import tarfile
22 | overwrite = True
23 | open_type = 'w' if overwrite else 'a'
24 | train_f = xz.open("./cache/train.undebates.xz", open_type)
25 | val_f = xz.open("./cache/validation.undebates.xz", open_type)
26 | # if fname.endswith("tar.gz"):
27 | tar = tarfile.open(cache_path, "r:gz")
28 | tar.extractall()
29 | tar.close()
30 | for path, subdirs, files in os.walk("./TXT/"):
31 | for name in files:
32 | if not name.endswith(".txt") or name[0] == ".":
33 | continue
34 |
35 | with open(os.path.join(path, name), "r") as f:
36 | text = str(f.read())
37 |
38 | datapoint = {
39 | "text" : text,
40 | "created_timestamp" : name.split("_")[-1].replace(".txt", ""),
41 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
42 | "url" : url
43 | }
44 |
45 | if random.random() > .75:
46 | val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
47 | else:
48 | train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
49 |
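# A tiny worked example (hypothetical filename) of the created_timestamp
# extraction above, assuming the corpus's COUNTRY_SESSION_YEAR.txt naming.
name = "USA_75_2020.txt"
print(name.split("_")[-1].replace(".txt", ""))  # -> 2020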
--------------------------------------------------------------------------------
/dataset_creation/us_bills/process_us_bills.py:
--------------------------------------------------------------------------------
1 | # Processes US congressional bills from the 108th - 117th Congress
2 | # Bulk data pulled from: https://www.govinfo.gov/bulkdata/xml/BILLSTATUS
3 | # Requests to bulk data API endpoint are prone to request errors,
4 | # use run.sh to call this script for each session for greater fault tolerance
5 |
6 | import os
7 | import sys
8 | import json
9 | import datetime
10 | import random
11 | from urllib.request import Request, urlopen
12 | # xpath only available in lxml etree, not ElementTree
13 | from lxml import etree
14 | from tqdm import tqdm
15 | from dateutil import parser
16 | import time
17 | import argparse
18 |
19 | # BASE URL (bulk data API endpoint)
20 | BASE_URL = "https://www.govinfo.gov/bulkdata/xml/BILLSTATUS"
21 |
22 | # DATA DIRS
23 | SAVE_DIR = "../../data/us_bills/save" # path to save per session data
24 |
25 | # Request variables
26 | headers = {'User-Agent': 'Mozilla/5.0', 'Accept': 'application/xml'}
27 |
28 |
29 | def save_to_file(data, out_dir, fname):
30 | if not os.path.exists(out_dir):
31 | os.makedirs(out_dir)
32 | fpath = os.path.join(out_dir, fname)
33 | with open(fpath, 'w') as out_file:
34 | for x in data:
35 | out_file.write(json.dumps(x) + "\n")
36 | print(f"Written {len(data)} to {fpath}")
37 |
38 |
39 | def request_status_xmls(session):
40 | urls = [os.path.join(BASE_URL, str(session))]
41 | status_xmls = []
42 | while len(urls) > 0:
43 | next_urls = []
44 | for url in urls:
45 | print(url)
46 | request = Request(url, headers=headers)
47 | time.sleep(random.uniform(0.02, 0.05))
48 | with urlopen(request) as response:
49 | root = etree.fromstring(response.read())
50 | elems = root.xpath("*/file[folder='true' and name!='resources']")
51 | if len(elems) > 0:
52 | for e in elems:
53 | next_url = e.find("link").text
54 | next_urls.append(next_url)
55 | else:
56 | elems = root.xpath("*/file[mimeType='application/xml']")
57 | for e in elems:
58 | xml_url = e.find("link").text
59 | # print(xml_url)
60 | request = Request(xml_url, headers=headers)
61 | for i in range(3): # retry request max of 3 times
62 | try:
63 | time.sleep(random.uniform(0.02, 0.05))
64 | with urlopen(request) as response:
65 | xml = etree.fromstring(response.read())
66 | # Add xml for bill status
67 | status_xmls.append(xml)
68 | break
69 | except:
70 | print("Retrying")
71 |
72 | urls = next_urls
73 |
74 | return status_xmls
75 |
76 |
77 | def request_raw_data(status_xmls):
78 | xmls = []
79 | for status_xml in tqdm(status_xmls):
80 | # print(status_xml)
81 | # Text versions are sorted in date order, find returns first item, which is most recent version
82 | text_info = status_xml.find("bill/textVersions/item")
83 | if text_info is not None:
84 | try:
85 | date = text_info.find("date").text
86 | except:
87 | date = ""
88 | try:
89 | xml_url = text_info.find("formats/item/url").text
90 | except:
91 | xml_url = None
92 | if xml_url:
93 | request = Request(xml_url, headers=headers)
94 | for i in range(3):
95 | try:
96 | time.sleep(random.uniform(0.02, 0.05))
97 | with urlopen(request) as response:
98 | xml = etree.fromstring(response.read())
99 | # print(date, xml_url)
100 | # Add tuple of (date, xml_url, xml) for raw bill text
101 | xmls.append((date, xml_url, xml))
102 | break
103 | except:
104 | print("Retrying")
105 |
106 | return xmls
107 |
108 |
109 | def prepare_docs(xmls):
110 | docs = []
111 | for (date, xml_url, xml) in tqdm(xmls):
112 | try:
113 | creation_date = parser.parse(date).strftime("%m-%d-%Y")
114 | except:
115 | creation_date = ""
116 |
117 | # In Python 3, use encoding='unicode'
118 | # In Python 2, use encoding='utf-8' and decode
119 | all_text = etree.tostring(xml, encoding='unicode', method='text')
120 |
121 | doc = {
122 | "url": xml_url,
123 | "created_timestamp": creation_date,
124 | "downloaded_timestamp": datetime.date.today().strftime("%m-%d-%Y"),
125 | "text": all_text
126 | }
127 | docs.append(doc)
128 |
129 | return docs
130 |
131 |
132 | def main():
133 | args = arg_parser.parse_args()
134 |
135 | # Request raw data directly using bulk data API
136 | print("Request status xmls")
137 | status_xmls = request_status_xmls(session=args.session)
138 | print("Request raw data")
139 | xmls = request_raw_data(status_xmls)
140 | print("Prepare docs")
141 | docs = prepare_docs(xmls)
142 |
143 | save_to_file(docs, SAVE_DIR, str(args.session) + ".us_bills.jsonl")
144 |
145 |
146 | if __name__ == '__main__':
147 | arg_parser = argparse.ArgumentParser()
148 | arg_parser.add_argument('--session', type=int, default=117)
149 | main()
--------------------------------------------------------------------------------
/dataset_creation/us_bills/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | python3 process_us_bills.py --session=117
3 | python3 process_us_bills.py --session=116
4 | python3 process_us_bills.py --session=115
5 | python3 process_us_bills.py --session=114
6 | python3 process_us_bills.py --session=113
7 | python3 process_us_bills.py --session=112
8 | python3 process_us_bills.py --session=111
9 | python3 process_us_bills.py --session=110
10 | python3 process_us_bills.py --session=109
11 | python3 process_us_bills.py --session=108
--------------------------------------------------------------------------------
/dataset_creation/us_bills/split_us_bills.py:
--------------------------------------------------------------------------------
1 | # Loads data files saved by session (108th-117th Congress) from the bulk data API endpoints
2 | # Splits all data into train / validation sets
3 |
4 | import os
5 | import glob
6 | import json
7 | import random
8 | from process_us_bills import save_to_file
9 |
10 | # DATA DIRS
11 | SAVE_DIR = "../../data/us_bills/save" # path to save per session data
12 | OUT_DIR = "../../data/us_bills/processed" # path to write data out to
13 |
14 | def main():
15 | docs = []
16 | # Load all sessions from saved per-session data files
17 | session_files = glob.glob(os.path.join(SAVE_DIR, "*.jsonl"))
18 | for session_file in session_files:
19 | print(session_file)
20 | with open(session_file, 'r') as file:
21 | for line in file:
22 | docs.append(json.loads(line))
23 |
24 | # Shuffle and split into train / validation
25 | random.seed(0)
26 | random.shuffle(docs)
27 | train = docs[:int(len(docs)*0.75)]
28 | validation = docs[int(len(docs)*0.75):]
29 |
30 | print("Write train data")
31 | save_to_file(train, OUT_DIR, "train.us_bills.jsonl")
32 | print("Write validation data")
33 | save_to_file(validation, OUT_DIR, "validation.us_bills.jsonl")
34 |
35 |
36 | if __name__ == '__main__':
37 | main()
--------------------------------------------------------------------------------
/dataset_creation/uscode/scrape_us_code.py:
--------------------------------------------------------------------------------
1 | url = "https://uscode.house.gov/download/releasepoints/us/pl/117/49/xml_uscAll@117-49.zip"
2 |
3 | import zipfile
4 | import os
5 | import textract
6 | import random
7 | import bs4
8 | import json
9 | import datetime
10 | try:
11 | import lzma as xz
12 | except ImportError:
13 | import pylzma as xz
14 |
15 | cache_path = "./cache/xml_uscAll@117-49.zip"
16 | if not os.path.exists(cache_path):
17 | import wget
18 | filename = wget.download(url, out="./cache/")
19 |
20 | with zipfile.ZipFile(cache_path, 'r') as zip_ref:
21 | zip_ref.extractall("./cache/")
22 |
23 | overwrite = True
24 | open_type = 'w' if overwrite else 'a'
25 | train_f = xz.open("./cache/train.uscode.xz", open_type)
26 | val_f = xz.open("./cache/validation.uscode.xz", open_type)
27 | def html2text(x):
28 | soup = bs4.BeautifulSoup(x, "lxml")
29 | return soup.get_text()
30 |
31 | docs = []
32 | for path, subdirs, files in os.walk("./cache/"):
33 | for name in files:
34 | if not name.endswith(".xml"):
35 | continue
36 | with open(os.path.join(path, name), "r") as text_file:
37 | text = html2text(text_file.read())
38 | # text = str(textract.process(os.path.join(path, name)))
39 | datapoint = {
40 | "text" : text,
41 | "created_timestamp" : "2021",
42 | "downloaded_timestamp" : datetime.date.today().strftime("%m-%d-%Y"),
43 | "url" : url
44 | }
45 |
46 |         if random.random() > .75:
47 |             val_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
48 |         else:
49 |             train_f.write((json.dumps(datapoint) + "\n").encode("utf-8"))
50 | 
51 | # Close the compressed outputs so buffered data is flushed to disk
52 | train_f.close()
53 | val_f.close()
--------------------------------------------------------------------------------
/pretraining/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Vocab
3 |
4 | We use new_vocab.py to generate part of the vocabulary; the rest is supplemented with a random sample of entries from Black's Law Dictionary.
5 |
6 | # Chunkification and Randomization
7 |
8 | We use chunkify_and_hd5.py to segment and shuffle the data into shards. Each shard contains a homogeneous sample across the data segments; we found that skipping this step caused severe training instabilities. A minimal sketch of the idea is shown at the end of this README.
9 |
10 | # Fine-tuning
11 |
12 | Fine-tuning code is in pol-finetuning-main
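13 | 
14 | Below is a minimal sketch of the chunkification idea (not the actual chunkify_and_hd5.py): fixed-length chunks from every data segment are shuffled and interleaved so that each shard is a homogeneous mixture. The chunk length, shard size, helper names, and the use of h5py are illustrative assumptions.
15 | 
16 | ```python
17 | # Sketch: interleave fixed-size chunks from every source into each shard.
18 | import random
19 | 
20 | import h5py
21 | import numpy as np
22 | 
23 | CHUNK_LEN = 512        # tokens per training example (assumed)
24 | SHARD_SIZE = 100_000   # examples per shard (assumed)
25 | 
26 | 
27 | def chunkify(token_ids, chunk_len=CHUNK_LEN):
28 |     """Split one document's token ids into fixed-length chunks, dropping the tail."""
29 |     return [token_ids[i:i + chunk_len]
30 |             for i in range(0, len(token_ids) - chunk_len + 1, chunk_len)]
31 | 
32 | 
33 | def write_shards(sources, out_prefix="shard"):
34 |     """`sources` maps a source name to a list of token-id lists (one per document)."""
35 |     per_source = {name: [c for doc in docs for c in chunkify(doc)]
36 |                   for name, docs in sources.items()}
37 |     for chunk_list in per_source.values():
38 |         random.shuffle(chunk_list)
39 |     # Round-robin across sources so every shard mixes all of them.
40 |     mixed = []
41 |     while any(per_source.values()):
42 |         for name in list(per_source):
43 |             if per_source[name]:
44 |                 mixed.append(per_source[name].pop())
45 |     for start in range(0, len(mixed), SHARD_SIZE):
46 |         shard = np.array(mixed[start:start + SHARD_SIZE], dtype=np.int32)
47 |         with h5py.File(f"{out_prefix}_{start // SHARD_SIZE}.h5", "w") as f:
48 |             f.create_dataset("input_ids", data=shard)
49 | ```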
--------------------------------------------------------------------------------
/pretraining/new_vocab.py:
--------------------------------------------------------------------------------
1 | # !pip install tokenizers
2 | import os
3 | from tokenizers import BertWordPieceTokenizer
4 |
5 | fileList = [os.path.join("./unzipped_nojson", x) for x in os.listdir("./unzipped_nojson")]
6 |
7 | # initialize
8 | tokenizer = BertWordPieceTokenizer(
9 | clean_text=False,
10 | handle_chinese_chars=True,
11 | strip_accents=True,
12 | lowercase=True
13 | )
14 |
15 | # and train
16 | tokenizer.train(files=fileList, vocab_size=29000, min_frequency=2,
17 | limit_alphabet=500, wordpieces_prefix='##',
18 | special_tokens=[
19 | '[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])
20 |
21 | os.makedirs('./new_bert_vocab_2/', exist_ok=True)  # make sure the output directory exists
22 | tokenizer.save_model('./new_bert_vocab_2/', 'bert-wordpiece')
23 | 
--------------------------------------------------------------------------------
/pretraining/pol-finetuning-main/README.md:
--------------------------------------------------------------------------------
1 | # pol-finetuning
2 |
3 | ## Setup
4 | - Create `pol_models` directory
5 | - Download Model4 (Models4 folder in GDrive) and pol-roberta weights from the Google Drive to the following paths in `pol_models`
6 | ```
7 | pol_models
8 | Model4
9 | config.json
10 | pytorch_model.bin
11 | tokenizer_config.json
12 | vocab.txt
13 | ```
14 | - Conda env file: `environment.yml`
15 | - Originally run on a GCP instance with 4 x NVIDIA Tesla A100 GPUs; adjust the batch size for the machine you're running on
16 | 
17 |
18 |
19 |
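20 | A quick sanity check that the downloaded weights load, assuming the `pol_models/Model4` layout above (a minimal sketch, not part of the fine-tuning pipeline):
21 | 
22 | ```python
23 | # Sketch: load the downloaded checkpoint with the standard transformers API.
24 | from transformers import AutoModel, AutoTokenizer
25 | 
26 | tokenizer = AutoTokenizer.from_pretrained("pol_models/Model4")
27 | model = AutoModel.from_pretrained("pol_models/Model4")
28 | print(model.config.model_type, model.num_parameters())
29 | ```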
--------------------------------------------------------------------------------
/pretraining/pol-finetuning-main/environment.yml:
--------------------------------------------------------------------------------
1 | name: pol
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1
8 | - _pytorch_select=0.1=cpu_0
9 | - absl-py=1.0.0=pyhd8ed1ab_0
10 | - blas=1.0=mkl
11 | - c-ares==1.18.1
12 | - ca-certificates=2021.10.8
13 | - cudatoolkit
14 | - freetype=2.10.4
15 | - giflib=5.2.1
16 | - grpcio==1.46.1
17 | - importlib-metadata==4.11.3
18 | - intel-openmp
19 | - jbig==2.1
20 | - jpeg==9e
21 | - lcms2=2.12
22 | - ld_impl_linux-64==2.36.1
23 | - lerc=3.0
24 | - libblas=3.9.0
25 | - libcblas=3.9.0
26 | - libdeflate=1.10
27 | - libffi=3.4.2
28 | - libgfortran-ng
29 | - libgfortran5
30 | - liblapack==3.9.0
31 | - libopenblas=0.3.20
32 | - libpng=1.6.37
33 | - libprotobuf==3.20.1
34 | - libtiff==4.3.0
35 | - libuv==1.43.0
36 | - libwebp==1.2.2
37 | - libwebp-base==1.2.2
38 | - libxcb=1.13
39 | - libzlib=1.2.11
40 | - lz4-c=1.9.3
41 | - markdown=3.3.7=pyhd8ed1ab_0
42 | - mkl
43 | - ncurses=6.3
44 | - ninja=1.10.2
45 | - numpy=1.21.6
46 | - openjpeg==2.4.0
47 | - openssl=3.0.3
48 | - pillow==9.1.0
49 | - pip=22.1
50 | - pthread-stubs==0.4
51 | - python=3.7.12
52 | - python_abi=3.7=2_cp37m
53 | - pytorch=1.7.1
54 | - readline=8.1
55 | - setuptools==62.2.0
56 | - six=1.16.0=pyh6c4a22f_0
57 | - sqlite
58 | - tensorboard=1.15.0=py37_0
59 | - tk==8.6.12
60 | - torchaudio=0.7.2=py37
61 | - torchvision=0.8.2
62 | - typing_extensions=4.2.0=pyha770c72_1
63 | - werkzeug=2.1.2=pyhd8ed1ab_1
64 | - wheel=0.37.1=pyhd8ed1ab_0
65 | - xorg-libxau==1.0.9
66 | - xorg-libxdmcp=1.1.3
67 | - xz==5.2.5
68 | - zipp=3.8.0=pyhd8ed1ab_0
69 | - zlib==1.2.11
70 | - zstd==1.5.2
71 | - pip:
72 | - aiohttp==3.8.1
73 | - aiosignal==1.2.0
74 | - apex==0.1
75 | - async-timeout==4.0.2
76 | - asynctest==0.13.0
77 | - attrs==21.4.0
78 | - certifi==2021.10.8
79 | - charset-normalizer==2.0.12
80 | - click==8.1.3
81 | - datasets==2.1.0
82 | - dill==0.3.4
83 | - docker-pycreds==0.4.0
84 | - filelock==3.6.0
85 | - frozenlist==1.3.0
86 | - fsspec==2022.3.0
87 | - gitdb==4.0.9
88 | - gitpython==3.1.27
89 | - huggingface-hub==0.5.1
90 | - idna==3.3
91 | - joblib==1.1.0
92 | - multidict==6.0.2
93 | - multiprocess==0.70.12.2
94 | - nltk==3.7
95 | - packaging==21.3
96 | - pandas==1.3.5
97 | - pathtools==0.1.2
98 | - promise==2.3
99 | - protobuf==3.20.1
100 | - psutil==5.9.0
101 | - pyarrow==7.0.0
102 | - pyparsing==3.0.8
103 | - python-dateutil==2.8.2
104 | - pytz==2022.1
105 | - pyyaml==6.0
106 | - regex==2022.4.24
107 | - requests==2.27.1
108 | - responses==0.18.0
109 | - sacremoses==0.0.49
110 | - scikit-learn==1.0.2
111 | - scipy==1.7.3
112 | - sentry-sdk==1.5.10
113 | - setproctitle==1.2.3
114 | - shortuuid==1.0.8
115 | - smmap==5.0.0
116 | - threadpoolctl==3.1.0
117 | - tokenizers==0.12.1
118 | - tqdm==4.64.0
119 | - transformers==4.18.0
120 | - urllib3==1.26.9
121 | - wandb==0.12.15
122 | - xxhash==3.0.0
123 | - yarl==1.7.2
124 | - libgcc-ng
125 | - libgomp
126 | - libnsl
127 | - libstdcxx-ng
128 | prefix: /opt/conda/envs/pol
129 |
--------------------------------------------------------------------------------
/pretraining/pol-finetuning-main/scripts/run_casehold.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT=pol
2 | TASK='case_hold'
3 | SEED=1
4 |
5 |
6 | MODEL='roberta-base'
7 | MODEL_NAME_OR_PATH=${MODEL}
8 |
9 |
10 |
11 | python tasks/casehold.py \
12 | --model_name_or_path ${MODEL_NAME_OR_PATH} \
13 | --task_name ${TASK} \
14 | --do_train \
15 | --do_eval \
16 | --evaluation_strategy='steps' \
17 | --eval_steps 500 \
18 | --save_strategy='steps' \
19 | --save_steps 0 \
20 | --logging_steps 500 \
21 | --max_seq_length 512 \
22 | --per_device_train_batch_size=2 \
23 | --per_device_eval_batch_size=64 \
24 | --learning_rate=1e-5 \
25 | --num_train_epochs=7 \
26 | --seed ${SEED} \
27 | --fp16 \
28 | --report_to=wandb \
29 | --run_name=${TASK}/${MODEL}/seed_${SEED} \
30 | --output_dir=logs/${TASK}/${MODEL}/lr=1e-5/seed_${SEED} \
31 | --overwrite_output_dir
--------------------------------------------------------------------------------
/pretraining/pol-finetuning-main/tasks/adam_bias_correct.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Callable, Iterable, Tuple
3 |
4 | import torch
5 | from torch import nn
6 |
7 | from transformers.optimization import AdamW
8 |
9 | class AdamWBC(AdamW):
10 | """
11 |     Implements the Adam algorithm with bias correction for BERT; helps with training stability for large models
12 | Parameters:
13 | params (`Iterable[nn.parameter.Parameter]`):
14 | Iterable of parameters to optimize or dictionaries defining parameter groups.
15 | lr (`float`, *optional*, defaults to 1e-3):
16 | The learning rate to use.
17 | betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)):
18 | Adam's betas parameters (b1, b2).
19 | eps (`float`, *optional*, defaults to 1e-6):
20 | Adam's epsilon for numerical stability.
21 | weight_decay (`float`, *optional*, defaults to 0):
22 | Decoupled weight decay to apply.
23 | correct_bias (`bool`, *optional*, defaults to `True`):
24 | Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
25 | no_deprecation_warning (`bool`, *optional*, defaults to `False`):
26 | A flag used to disable the deprecation warning (set to `True` to disable the warning).
27 | """
28 |
29 | def __init__(
30 | self,
31 | params: Iterable[nn.parameter.Parameter],
32 | lr: float = 1e-3,
33 | betas: Tuple[float, float] = (0.9, 0.999),
34 | eps: float = 1e-6,
35 | weight_decay: float = 0.0,
36 | correct_bias: bool = True,
37 | no_deprecation_warning: bool = False,
38 | ):
39 | super().__init__(params, lr, betas, eps, weight_decay, correct_bias, no_deprecation_warning)
40 |
41 | def step(self, closure: Callable = None):
42 | """
43 | Performs a single optimization step.
44 | Arguments:
45 | closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
46 | """
47 | loss = None
48 | if closure is not None:
49 | loss = closure()
50 |
51 | for group in self.param_groups:
52 | for p in group["params"]:
53 | if p.grad is None:
54 | continue
55 | grad = p.grad.data
56 | if grad.is_sparse:
57 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
58 |
59 | state = self.state[p]
60 |
61 | # State initialization
62 | if len(state) == 0:
63 | state["step"] = 0
64 | # Exponential moving average of gradient values
65 | state["exp_avg"] = torch.zeros_like(p.data)
66 | # Exponential moving average of squared gradient values
67 | state["exp_avg_sq"] = torch.zeros_like(p.data)
68 |
69 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
70 | beta1, beta2 = group["betas"]
71 |
72 | state["step"] += 1
73 |
74 | # Decay the first and second moment running average coefficient
75 | # In-place operations to update the averages at the same time
76 | exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
77 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
78 | denom = exp_avg_sq.sqrt().add_(group["eps"])
79 |
80 | step_size = group["lr"]
81 | # Bias correction for Bert
82 | bias_correction1 = 1.0 - beta1 ** state["step"]
83 | bias_correction2 = 1.0 - beta2 ** state["step"]
84 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
85 |
86 | p.data.addcdiv_(exp_avg, denom, value=-step_size)
87 |
88 | # Just adding the square of the weights to the loss function is *not*
89 | # the correct way of using L2 regularization/weight decay with Adam,
90 | # since that will interact with the m and v parameters in strange ways.
91 | #
92 | # Instead we want to decay the weights in a manner that doesn't interact
93 | # with the m/v parameters. This is equivalent to adding the square
94 | # of the weights to the loss with plain (non-momentum) SGD.
95 | # Add weight decay at the end (fixed version)
96 | if group["weight_decay"] > 0.0:
97 | p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"]))
98 |
99 | return loss
--------------------------------------------------------------------------------
/pretraining/pol-finetuning-main/tasks/casehold_helpers.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from dataclasses import dataclass
4 | from enum import Enum
5 | from typing import List, Optional
6 |
7 | import tqdm
8 | import re
9 |
10 | from filelock import FileLock
11 | from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
12 | import datasets
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | @dataclass(frozen=True)
18 | class InputFeatures:
19 | """
20 | A single set of features of data.
21 | Property names are the same names as the corresponding inputs to a model.
22 | """
23 |
24 | input_ids: List[List[int]]
25 | attention_mask: Optional[List[List[int]]]
26 | token_type_ids: Optional[List[List[int]]]
27 | label: Optional[int]
28 |
29 |
30 | class Split(Enum):
31 | train = "train"
32 | dev = "dev"
33 | test = "test"
34 |
35 |
36 | if is_torch_available():
37 | import torch
38 | from torch.utils.data.dataset import Dataset
39 |
40 | class MultipleChoiceDataset(Dataset):
41 | """
42 | PyTorch multiple choice dataset class
43 | """
44 |
45 | features: List[InputFeatures]
46 |
47 | def __init__(
48 | self,
49 | tokenizer: PreTrainedTokenizer,
50 | task: str,
51 | max_seq_length: Optional[int] = None,
52 | overwrite_cache=False,
53 | mode: Split = Split.train,
54 | ):
55 | dataset = datasets.load_dataset('lex_glue', task)
56 | tokenizer_name = re.sub('[^a-z]+', ' ', tokenizer.name_or_path).title().replace(' ', '')
57 | cached_features_file = os.path.join(
58 | '.cache',
59 | task,
60 | "cached_{}_{}_{}_{}".format(
61 | mode.value,
62 | tokenizer_name,
63 | str(max_seq_length),
64 | task,
65 | ),
66 | )
67 |
68 | # Make sure only the first process in distributed training processes the dataset,
69 | # and the others will use the cache.
70 | lock_path = cached_features_file + ".lock"
71 | if not os.path.exists(os.path.join('.cache', task)):
72 | if not os.path.exists('.cache'):
73 | os.mkdir('.cache')
74 | os.mkdir(os.path.join('.cache', task))
75 | with FileLock(lock_path):
76 |
77 | if os.path.exists(cached_features_file) and not overwrite_cache:
78 | logger.info(f"Loading features from cached file {cached_features_file}")
79 | self.features = torch.load(cached_features_file)
80 | else:
81 | logger.info(f"Creating features from dataset file at {task}")
82 | if mode == Split.dev:
83 | examples = dataset['validation']
84 | elif mode == Split.test:
85 | examples = dataset['test']
86 | elif mode == Split.train:
87 | examples = dataset['train']
88 | logger.info("Training examples: %s", len(examples))
89 | self.features = convert_examples_to_features(
90 | examples,
91 | max_seq_length,
92 | tokenizer,
93 | )
94 | logger.info("Saving features into cached file %s", cached_features_file)
95 | torch.save(self.features, cached_features_file)
96 |
97 | def __len__(self):
98 | return len(self.features)
99 |
100 | def __getitem__(self, i) -> InputFeatures:
101 | return self.features[i]
102 |
103 |
104 | if is_tf_available():
105 | import tensorflow as tf
106 |
107 | class TFMultipleChoiceDataset:
108 | """
109 | TensorFlow multiple choice dataset class
110 | """
111 |
112 | features: List[InputFeatures]
113 |
114 | def __init__(
115 | self,
116 | tokenizer: PreTrainedTokenizer,
117 | task: str,
118 | max_seq_length: Optional[int] = 256,
119 | overwrite_cache=False,
120 | mode: Split = Split.train,
121 | ):
122 |             dataset = datasets.load_dataset('lex_glue', task)
123 |
124 | logger.info(f"Creating features from dataset file at {task}")
125 | if mode == Split.dev:
126 | examples = dataset['validation']
127 | elif mode == Split.test:
128 | examples = dataset['test']
129 | else:
130 | examples = dataset['train']
131 | logger.info("Training examples: %s", len(examples))
132 |
133 | self.features = convert_examples_to_features(
134 | examples,
135 | max_seq_length,
136 | tokenizer,
137 | )
138 |
139 | def gen():
140 | for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
141 | if ex_index % 10000 == 0:
142 | logger.info("Writing example %d of %d" % (ex_index, len(examples)))
143 |
144 | yield (
145 | {
146 | "input_ids": ex.input_ids,
147 | "attention_mask": ex.attention_mask,
148 | "token_type_ids": ex.token_type_ids,
149 | },
150 | ex.label,
151 | )
152 |
153 | self.dataset = tf.data.Dataset.from_generator(
154 | gen,
155 | (
156 | {
157 | "input_ids": tf.int32,
158 | "attention_mask": tf.int32,
159 | "token_type_ids": tf.int32,
160 | },
161 | tf.int64,
162 | ),
163 | (
164 | {
165 | "input_ids": tf.TensorShape([None, None]),
166 | "attention_mask": tf.TensorShape([None, None]),
167 | "token_type_ids": tf.TensorShape([None, None]),
168 | },
169 | tf.TensorShape([]),
170 | ),
171 | )
172 |
173 | def get_dataset(self):
174 | self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
175 |
176 | return self.dataset
177 |
178 | def __len__(self):
179 | return len(self.features)
180 |
181 | def __getitem__(self, i) -> InputFeatures:
182 | return self.features[i]
183 |
184 |
185 | def convert_examples_to_features(
186 | examples: datasets.Dataset,
187 | max_length: int,
188 | tokenizer: PreTrainedTokenizer,
189 | ) -> List[InputFeatures]:
190 | """
191 | Loads a data file into a list of `InputFeatures`
192 | """
193 | features = []
194 | for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
195 | if ex_index % 10000 == 0:
196 | logger.info("Writing example %d of %d" % (ex_index, len(examples)))
197 | choices_inputs = []
198 | for ending_idx, ending in enumerate(example['endings']):
199 | context = example['context']
200 | inputs = tokenizer(
201 | context,
202 | ending,
203 | add_special_tokens=True,
204 | max_length=max_length,
205 | padding="max_length",
206 | truncation=True,
207 | )
208 |
209 | choices_inputs.append(inputs)
210 |
211 | label = example['label']
212 |
213 | input_ids = [x["input_ids"] for x in choices_inputs]
214 | attention_mask = (
215 | [x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None
216 | )
217 | token_type_ids = (
218 | [x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None
219 | )
220 |
221 | features.append(
222 | InputFeatures(
223 | input_ids=input_ids,
224 | attention_mask=attention_mask,
225 | token_type_ids=token_type_ids,
226 | label=label,
227 | )
228 | )
229 |
230 | for f in features[:2]:
231 | logger.info("*** Example ***")
232 | logger.info("feature: %s" % f)
233 |
234 | return features
235 |
--------------------------------------------------------------------------------
/privacy/README.md:
--------------------------------------------------------------------------------
1 | # Privacy Experiments
2 |
3 | ## EOIR
4 |
5 | The EOIR experiments and links to data/models are available in the eoir folder.
6 |
7 |
8 | ## Jane Doe
9 |
10 | The Jane Doe experiments can be found in the janedoe folder.
11 |
12 |
13 |
--------------------------------------------------------------------------------
/privacy/eoir/README.md:
--------------------------------------------------------------------------------
1 |
2 | # EOIR Pseudonyms Experiment
3 |
4 | In this experiment, we first create a dataset of paragraphs by running:
5 |
6 | ```
7 | python create_pseudonyms_dataset.py
8 | ```
9 |
10 | We upload the generated dataset linking data to labels in [this HuggingFace Repo](https://huggingface.co/datasets/pile-of-law/eoir_privacy).
11 |
12 | Then we use a Colab notebook to train a DistilBERT model, also available as an IPython notebook:
13 |
14 | ```
15 | EOIR.ipynb
16 | ```
17 |
18 | The resulting model is runnable and we [also upload it to HF](https://huggingface.co/pile-of-law/distilbert-base-uncased-finetuned-eoir_privacy); see the loading sketch at the end of this README.
19 |
20 | Then we run a perturbation experiment, as seen in:
21 |
22 | ```
23 | EOIR_validation_exp.ipynb
24 | ```
25 |
26 | For the causal lexicon experiment we use:
27 |
28 | ```
29 | causal_exp.py
30 | ```
31 |
32 |
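33 | As a quick check, the released dataset and classifier can also be loaded directly from the Hub. A minimal sketch (the dataset may require `use_auth_token=True`, and the split name follows causal_exp.py):
34 | 
35 | ```python
36 | # Sketch: score one paragraph with the released eoir_privacy classifier.
37 | from datasets import load_dataset
38 | from transformers import pipeline
39 | 
40 | ds = load_dataset("pile-of-law/eoir_privacy", "all", split="all", use_auth_token=True)
41 | clf = pipeline(
42 |     "text-classification",
43 |     model="pile-of-law/distilbert-base-uncased-finetuned-eoir_privacy",
44 | )
45 | print(clf(ds[0]["text"], truncation=True))  # predicted pseudonymization label
46 | ```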
--------------------------------------------------------------------------------
/privacy/eoir/causal_exp.py:
--------------------------------------------------------------------------------
1 | import json
2 | from datasets import load_dataset
3 | import spacy
4 | from citeurl import Citator
5 | citator = Citator()
6 |
7 | from collections import Counter
8 | def is_number(s):
9 | try:
10 | float(s)
11 | return True
12 | except ValueError:
13 | return False
14 |
15 | def build_vocab(texts, max_vocab=10000, min_freq=3):
16 | nlp = spacy.load("en_core_web_sm") # just the tokenizer
17 | wc = Counter()
18 | for doc in nlp.pipe(texts):
19 | citations = citator.list_cites(doc.text)
20 | for x in citations:
21 | wc[str(x).lower()] += 1
22 | for word in doc:
23 | if word.ent_type_ == "CARDINAL":
24 | continue
25 | if is_number(word.text):
26 | continue
27 | if len(word.text.strip()) <= 2:
28 | continue
29 | wc[word.lower_] += 1
30 |
31 | word2id = {}
32 | id2word = {}
33 | for word, count in wc.most_common():
34 | if count < min_freq: break
35 | if len(word2id) >= max_vocab: break
36 | wid = len(word2id)
37 | word2id[word] = wid
38 | id2word[wid] = word
39 | return list([x for x in word2id.keys()])
40 |
41 | process = True
42 | if process:
43 | dataset = load_dataset("pile-of-law/eoir_privacy", "all", split="all", use_auth_token=True)
44 |
45 | data = {}
46 |
47 | dataset.to_csv("descriptions.csv")
48 |
49 | import pandas as pd
50 | x = pd.read_csv("descriptions.csv")
51 | pseudo_label = { 0 : "no_pseudo", 1 : "pseudo"}
52 | x["label"] = [pseudo_label[z] for z in x["label"]]
53 | x = x.drop_duplicates("text")
54 | pres = pd.read_csv("presidents.csv")
55 |
56 | year_to_pres_map = {}
57 | for name, years in zip(pres["President Name"], pres["Years In Office"]):
58 | splitted_years = years.split("-")
59 | if splitted_years[-1].strip() == "":
60 | splitted_years = [splitted_years[0]]
61 | if len(splitted_years) > 1:
62 | for i in range(int(splitted_years[0]), int(splitted_years[1])+1):
63 | year_to_pres_map[i] = name
64 | else:
65 | year_to_pres_map[int(splitted_years[0])] = name
66 |
67 | x["president"] = [year_to_pres_map[int(year)] if is_number(year) and int(year) in year_to_pres_map.keys() else "UNK" for year in x["year"]]
68 | x = x[[pres != "UNK" for pres in x["president"]]]
69 | x = x[[True if is_number(year) else False for year in x["year"]]]
70 |
71 | x.to_csv("descriptions.csv")
72 |
73 | vocab = build_vocab(dataset["text"])
74 | with open("vocab.txt", "w") as f:
75 | for v in vocab:
76 | f.write(v + "\n")
77 |
78 | with open("vocab.txt", "r") as f:
79 | vocab = [x.strip() for x in f.readlines()]
80 |
81 | print("Finished vocab!")
82 | import causal_attribution
83 | merged_scores = {}
84 | for method in ["residualization"]:
85 | for hidden_size in [512]:
86 | for lr in [7e-4]:
87 | importance_scores = causal_attribution.score_vocab(
88 | vocab=vocab,
89 | scoring_model=method,
90 | hidden_size = hidden_size,
91 | lr=lr,
92 | train_steps = 3000,
93 | max_seq_len=750,
94 | status_bar=True,
95 | use_gpu=True,
96 | csv="descriptions.csv",
97 | delimiter=",",
98 | name_to_type={
99 | 'text': 'input',
100 | #'name' : 'control',
101 | #'president' : 'control',
102 | 'year' : 'control',
103 | 'label': 'predict'
104 | })
105 | for (key, value) in importance_scores["label"]["pseudo"]:
106 | if key not in merged_scores:
107 | merged_scores[key] = []
108 | merged_scores[key].append(value)
109 |
110 | averaged = {}
111 | import numpy as np
112 | for key, val in merged_scores.items():
113 | averaged[key] = np.mean(val)
114 |
115 | with open("averaged_results.json", "w") as f:
116 | f.write(json.dumps(averaged))
117 | import pdb; pdb.set_trace()
118 |
--------------------------------------------------------------------------------
/privacy/eoir/create_pseudonyms_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import syntok.segmenter as segmenter
3 | import re
4 | import random
5 | try:
6 | import lzma as xz
7 | except ImportError:
8 | import pylzma as xz
9 |
10 | with open('../../dataset_creation/eoir/cache/train.eoir_pseudonym.jsonl', 'r') as f:
11 | data = [json.loads(x) for x in f.readlines()]
12 |
13 | import spacy
14 | from collections import Counter
15 |
16 | import re
17 | nlp = spacy.load("en_core_web_sm")
18 | replacement = re.compile(r"(the )?(respondent|applicant|defendant|plaintiff|petitioner)(s)?", re.IGNORECASE)
19 | def paragraphs(document):
20 | start = 0
21 | for token in document:
22 | if token.is_space and token.text.count("\n") > 1:
23 | yield document[start:token.i]
24 | start = token.i
25 | yield document[start:]
26 |
27 | with xz.open("./train.privacy.eoir.jsonl.xz", "wt") as f1:
28 | with xz.open("./validation.privacy.eoir.jsonl.xz", "wt") as f2:
29 | for datapoint in data:
30 | doc = nlp(datapoint["text"])
31 | for paragraph in paragraphs(doc):
32 | if len(paragraph.text) < 700:
33 | continue
34 | if ", Respondent" in paragraph.text:
35 | continue
36 | # TODO: better paragraph splitting
37 | text = paragraph.text
38 | text = text.replace("\n", " ")
39 |
40 | for ent in paragraph.ents:
41 | if ent.label_ == "PERSON":
42 | if not ("v." in text[max(0, ent.start-5):min(ent.end + 5, len(text))].lower()) or ("in re" in text[max(0, ent.start-5):min(ent.end + 5, len(text))].lower()):
43 | swapped_para = replacement.sub("[MASK]", text)
44 |
45 | if "[MASK]" not in swapped_para:
46 | continue
47 |
48 | print(datapoint["name"])
49 | new_data = {
50 | "text" : swapped_para,
51 | "label" : datapoint["is_pseudonym"],
52 | "year" : datapoint["created_timestamp"].split(".")[-1],
53 | "name" : datapoint["name"]
54 | }
55 |
56 | if random.random() < .15:
57 | f2.write(json.dumps(new_data) + "\n")
58 | else:
59 | f1.write(json.dumps(new_data) + "\n")
--------------------------------------------------------------------------------
/privacy/janedoe/README.md:
--------------------------------------------------------------------------------
1 | # Jane Doe Privacy Experiment
2 |
3 | To run the positive-example experiment, run jane_doe.py. Data is linked in the associated code.
4 |
5 | To run the negative-example experiment (where the Jane Doe pseudonym is not needed), run jane_doe_negative.py.
6 |
7 | Note: you have to change the model name inline to switch between bert and pol-bert.
8 |
9 |
10 |
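11 | For reference, the swap is the inline `model_name` assignment in jane_doe.py / jane_doe_negative.py. A sketch of the two settings whose outputs jane_doe_plot.py consumes:
12 | 
13 | ```python
14 | # pol-bert run (the default in the scripts)
15 | model_name = 'pile-of-law/legalbert-large-1.7M-2'
16 | # bert run (writes the jane_doe_diffs_bert-large-uncased*.json files read by jane_doe_plot.py)
17 | # model_name = 'bert-large-uncased'
18 | ```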
--------------------------------------------------------------------------------
/privacy/janedoe/jane_doe.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | import mxnet as mx
4 | import re
5 | import names
6 | import numpy as np
7 | import json
8 | from mlm.scorers import MLMScorer, MLMScorerPT, LMScorer
9 | from mlm.models import get_pretrained
10 |
11 | import numpy as np
12 |
13 | sheet_id = "1zf2xfYJ0dvmSVFHUATvDXqG8TFX3Gmt1CppiWTyscg0"
14 | sheet_name = "Sheet1"
15 | url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/gviz/tq?tqx=out:csv&sheet={sheet_name}"
16 |
17 | df = pd.read_csv(url, on_bad_lines='skip')
18 |
19 | df["Text"].dropna()
20 |
21 |
22 |
23 | ctxs = [mx.gpu(0)] # or, e.g., [mx.gpu(0), mx.gpu(1)]
24 | model_name = 'pile-of-law/legalbert-large-1.7M-2'
25 | model, vocab, tokenizer = get_pretrained(ctxs, model_name)
26 | scorer = MLMScorerPT(model, vocab, tokenizer, ctxs)
27 |
28 | insensitive_hippo = re.compile(re.escape('jane doe'), re.IGNORECASE)
29 | insensitive_hippo2 = re.compile(re.escape(' doe '), re.IGNORECASE)
30 | insensitive_hippo3 = re.compile(re.escape('jane roe'), re.IGNORECASE)
31 | insensitive_hippo4 = re.compile(re.escape(' roe '), re.IGNORECASE)
32 |
33 | average_diff = []
34 |
35 | for sent in list(df["Text"].dropna()):
36 | replacement_sentences = [sent]
37 | for i in range(2):
38 | new_name = names.get_full_name(gender="female")
39 | new_sent = insensitive_hippo.sub(new_name, sent)
40 | new_sent = insensitive_hippo2.sub(" " + new_name + " ", new_sent)
41 | new_sent = insensitive_hippo3.sub(new_name, new_sent)
42 | new_sent = insensitive_hippo4.sub(" " + new_name + " ", new_sent)
43 | replacement_sentences.append(new_sent)
44 |
45 | torch.cuda.empty_cache()
46 | with torch.no_grad():
47 | scores = scorer.score_sentences(replacement_sentences, split_size=51)
48 | main_score = scores[0]
49 | for i, x in enumerate(scores):
50 | if i == 0: continue
51 | diff = main_score - x
52 | average_diff.append(diff)
53 | print(np.mean(average_diff))
54 |
55 | with open(f"jane_doe_diffs_{model_name.split('/')[-1]}.json", "w") as f:
56 | f.write(json.dumps(average_diff))
57 |
58 |
--------------------------------------------------------------------------------
/privacy/janedoe/jane_doe_negative.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | import torch
4 | import pandas as pd
5 | from mlm.scorers import MLMScorer, MLMScorerPT, LMScorer
6 | from mlm.models import get_pretrained
7 | import pickle
8 | import numpy as np
9 |
10 | # To generate pickle, look to https://github.com/microsoft/biosbias
11 | with open("BIOS.pkl", "rb") as f:
12 | sentences = pickle.load(f)[:500]
13 |
14 | import mxnet as mx
15 | import re
16 | import names
17 | import numpy as np
18 |
19 | ctxs = [mx.gpu(0)] # or, e.g., [mx.gpu(0), mx.gpu(1)]
20 | model_name = 'pile-of-law/legalbert-large-1.7M-2' # Swap model name to use different model
21 | model, vocab, tokenizer = get_pretrained(ctxs, model_name)
22 | scorer = MLMScorerPT(model, vocab, tokenizer, ctxs)
23 |
24 | import re
25 |
26 |
27 | import names
28 | average_diff = []
29 |
30 | import random
31 | for sent in list(sentences):
32 | sent = sent["bio"]
33 |
34 | replacement_sentences = [sent.replace("_", "Jane Doe")]
35 |
36 | for i in range(2):
37 | new_name = names.get_full_name(gender="female")
38 | new_sent = sent.replace("_", new_name)
39 | replacement_sentences.append(new_sent)
40 |
41 | torch.cuda.empty_cache()
42 | with torch.no_grad():
43 | scores = scorer.score_sentences(replacement_sentences, split_size=51)
44 | main_score = scores[0]
45 | for i, x in enumerate(scores):
46 | if i == 0: continue
47 | diff = main_score - x
48 | average_diff.append(diff)
49 | print(np.mean(average_diff))
50 |
51 | with open(f"jane_doe_diffs_{model_name.split('/')[-1]}_negative.json", "w") as f:
52 | f.write(json.dumps(average_diff))
53 |
54 |
--------------------------------------------------------------------------------
/privacy/janedoe/jane_doe_plot.py:
--------------------------------------------------------------------------------
1 | import json
2 | import seaborn as sns
3 | import pandas as pd
4 |
5 | import matplotlib.pyplot as plt
6 |
7 | # These are the settings for plots in research papers and articles.
8 | # %matplotlib inline
9 | # %config InlineBackend.figure_format = 'retina'
10 |
11 | # Set the global font to be DejaVu Sans, size 10 (or any other sans-serif font of your choice!)
12 | plt.rc('font',**{'family':'sans-serif','sans-serif':['DejaVu Sans'],'size':16})
13 |
14 | # Set the font used for MathJax - more on this later
15 | plt.rc('mathtext',**{'default':'regular'})
16 |
17 | # Set the style for seaborn
18 | plt.style.use(['seaborn-whitegrid', 'seaborn-paper'])
19 |
20 | import matplotlib.pylab as pylab
21 | params = {'legend.fontsize': 'large',
22 | 'axes.labelsize': 'large',
23 | 'axes.titlesize': 'large',
24 | 'xtick.labelsize': 'large',
25 | 'ytick.labelsize': 'large'
26 | }
27 |
28 | pylab.rcParams.update(**params)
29 |
30 | import seaborn as sns
31 | sns.set_context(rc=params)
32 |
33 | def stylize_axes(ax, title):
34 | """
35 |     Stylize the axes by removing the spines and ticks.
36 | """
37 | # removes the top and right lines from the plot rectangle
38 | ax.spines['top'].set_visible(False)
39 | ax.spines['right'].set_visible(False)
40 |
41 | ax.xaxis.set_tick_params(top=False, direction='out', width=1)
42 | ax.yaxis.set_tick_params(right=False, direction='out', width=1)
43 |
44 | # Enforce the size of the title, label and tick labels
45 | ax.set_xlabel(ax.get_xlabel(), fontsize='large')
46 | ax.set_ylabel(ax.get_ylabel(), fontsize='large')
47 |
48 | ax.set_yticklabels(ax.get_yticklabels(), fontsize='medium')
49 | ax.set_xticklabels(ax.get_xticklabels(), fontsize='medium')
50 |
51 | ax.set_title(title, fontsize='large')
52 |
53 | def save_image(fig, title):
54 | """
55 | Save the figure as PNG and pdf files
56 | """
57 | if title is not None:
58 | fig.savefig(title+".png", dpi=300, bbox_inches='tight', transparent=True)
59 | fig.savefig(title+".pdf", bbox_inches='tight')
60 |
61 | def figure_size(fig, size):
62 | fig.set_size_inches(size)
63 | fig.tight_layout()
64 |
65 | def resadjust(ax, xres=None, yres=None):
66 | """
67 | Send in an axis and fix the resolution as desired.
68 | """
69 |
70 | if xres:
71 | start, stop = ax.get_xlim()
72 | ticks = np.arange(start, stop + xres, xres)
73 | ax.set_xticks(ticks)
74 | if yres:
75 | start, stop = ax.get_ylim()
76 | ticks = np.arange(start, stop + yres, yres)
77 | ax.set_yticks(ticks)
78 |
79 | with open("jane_doe_diffs_legalbert-large-1.7M-2.json", "r") as f:
80 | jane_doe_lb = json.load(f)
81 | with open("jane_doe_diffs_bert-large-uncased.json", "r") as f:
82 | jane_doe_b = json.load(f)
83 | with open("jane_doe_diffs_legalbert-large-1.7M-2_negative.json", "r") as f:
84 | jane_doe_lb_negjd = json.load(f)
85 | with open("jane_doe_diffs_bert-large-uncased_negative.json", "r") as f:
86 | jane_doe_b_negjd = json.load(f)
87 |
88 |
89 | import matplotlib.pyplot as plt
90 |
91 | df = pd.DataFrame.from_dict({
92 | "Model" : (["bert"] * len(jane_doe_b)) + (["pol-bert"] * len(jane_doe_lb)) + (["bert"] * len(jane_doe_b_negjd)) + (["pol-bert"] * len(jane_doe_lb_negjd)),
93 | "Jane Doe Score" : jane_doe_b + jane_doe_lb + jane_doe_b_negjd + jane_doe_lb_negjd,
94 | "Sample Type" : ["Court Case\n(Jane Doe)"] * (len(jane_doe_b) + len(jane_doe_lb)) + ["Bios"] * (len(jane_doe_b_negjd) + len(jane_doe_lb_negjd))
95 | })
96 |
97 | #swarm_plot = sns.barplot(x="Model", y="Jane Doe Score", data=df)
98 | swarm_plot = sns.factorplot(x = 'Sample Type', y='Jane Doe Score',
99 | hue = 'Model', data=df, kind='bar')
100 | #fig = swarm_plot.get_figure()
101 | plt.xticks(rotation=20)
102 |
103 | save_image(swarm_plot.fig, "janedoe")
104 | import numpy as np
105 | print(np.mean(jane_doe_b) - np.mean(jane_doe_b_negjd))
106 | print(np.mean(jane_doe_lb) - np.mean(jane_doe_lb_negjd))
--------------------------------------------------------------------------------
/scrub/README.md:
--------------------------------------------------------------------------------
1 | ## Scrub
2 | This directory contains the script, `scrub.py`, for scrubbing the dataset of sensitive information. `scrub.py` writes the scrubbed data to jsonl.xz files in the `data_scrubbed` directory and the filth data to json files in the `filth` directory.
3 |
4 | The filth data json files contain File Filth JSON objects with the following structure:
5 | ```
6 | {
7 | filename: name of the data file (dtype: string)
8 | filth_count: total count of filth detected across all documents in the data file (dtype: integer)
9 | word_count: total count of all words across all documents in the data file (dtype: integer)
10 | filth_doc_count: total count of documents with >0 filth in the data file (dtype: integer)
11 | doc_count: total count of documents in the data file (dtype: integer)
12 |     filth_data: list of Document Filth JSON objects, one per document in the data file; the order of the Document Filth JSON objects preserves the document order from the original data file (dtype: list[Document Filth JSON object])
13 | }
14 | ```
15 |
16 | where a Document Filth JSON object has the following structure:
17 | ```
18 | {
19 | url: source url where document was scraped (dtype: string)
20 | filth_count: count of filth detected in document (dtype: integer)
21 | word_count: count of all words in document (dtype: integer)
22 | filth: list of Filth JSON objects, corresponding to each item of filth detected in the document (dtype: list[Filth JSON object])
23 | }
24 | ```
25 |
26 | and a Filth JSON object has the following structure:
27 | ```
28 | {
29 | type: type of filth, corresponds to type field in [scrubadub.filth.Filth](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_filth.html) class (dtype: string or list[string])
30 | text: filth text / what was scrubbed out, corresponds to text field in [scrubadub.filth.Filth](https://scrubadub.readthedocs.io/en/stable/api_scrubadub_filth.html) class (dtype: string or list[string])
31 | merged: if is MergedFilth, filth is merged when multiple types of filth are detected for overlapping text, if merged = true, type and text will be lists (that share the same order) of all of the types and corresponding texts for which filth was detected (dtype: boolean)
32 | }
33 | ```
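34 | 
35 | For illustration only, a hypothetical filth file for a two-document data file might look like the following (the filename, URLs, types, and counts are made up):
36 | ```
37 | {
38 |   "filename": "train.example_subset.jsonl.xz",
39 |   "filth_count": 2,
40 |   "word_count": 1250,
41 |   "filth_doc_count": 1,
42 |   "doc_count": 2,
43 |   "filth_data": [
44 |     {
45 |       "url": "https://example.com/doc/1",
46 |       "filth_count": 2,
47 |       "word_count": 700,
48 |       "filth": [
49 |         {"type": "email", "text": "jane@example.com", "merged": false},
50 |         {"type": "phone", "text": "555-0100", "merged": false}
51 |       ]
52 |     },
53 |     {"url": "https://example.com/doc/2", "filth_count": 0, "word_count": 550, "filth": []}
54 |   ]
55 | }
56 | ```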
--------------------------------------------------------------------------------
/scrub/pii.py:
--------------------------------------------------------------------------------
1 | # Detects PII
2 | import glob
3 | import os
4 | import json
5 | import numpy as np
6 | import math
7 | import argparse
8 | from presidio_analyzer import AnalyzerEngine
9 | from presidio_anonymizer import AnonymizerEngine
10 | from tqdm import tqdm
11 | from datasets import load_dataset
12 | from multiprocessing import Pool, cpu_count
13 |
14 | # DATA DIRS
15 | DATA_URL = 'pile-of-law/pile-of-law'
16 | LOG_DIR = 'logs/pii'
17 |
18 | DATASETS = [
19 | 'bar_exam_outlines',
20 | 'fre',
21 | 'tos',
22 | 'ftc_advisory_opinions',
23 | 'frcp',
24 | 'constitutions',
25 | 'cfpb_creditcard_contracts',
26 | 'uscode',
27 | 'olc_memos',
28 | 'cc_casebooks',
29 | 'echr',
30 | 'federal_register',
31 | 'un_debates',
32 | 'euro_parl',
33 | 'scotus_oral_arguments',
34 | 'cfr',
35 | 'founding_docs',
36 | 'tax_rulings',
37 | 'r_legaladvice',
38 | 'eurlex',
39 | 'us_bills',
40 | 'nlrb_decisions',
41 | 'canadian_decisions',
42 | 'eoir',
43 | 'dol_ecab',
44 | # > 200MB
45 | 'oig',
46 | 'scotus_filings',
47 | 'state_codes',
48 | # > 1GB
49 | 'congressional_hearings',
50 | 'edgar',
51 | 'bva_opinions',
52 | 'courtlistener_docket_entry_documents',
53 | 'atticus_contracts',
54 | 'courtlistener_opinions',
55 | ]
56 |
57 | SPLITS = ['train', 'validation']
58 |
59 | # PII entities
60 | ENTITIES = ["US_BANK_NUMBER", "US_DRIVER_LICENSE", "US_ITIN", "US_PASSPORT", "US_SSN", "CREDIT_CARD", "PHONE_NUMBER", "EMAIL_ADDRESS"]
61 |
62 | # Threshold probabilities
63 | PII_THRESHOLD_PROB = 0.5
64 |
65 | MAX_LEN = int(1000000 / 10)
66 |
67 | # In characters
68 | CONTEXT_WINDOW = 20
69 |
70 | def scrub_doc(doc):
71 | analyzer, anonymizer = init_presidio()
72 |
73 | doc_text = doc["text"]
74 | doc_words = doc_text.split()
75 | doc_word_count = len(doc_words)
76 |
77 |
78 | def detect_pii(doc_text):
79 | analyzer_results = analyzer.analyze(text=doc_text, entities=ENTITIES, language='en')
80 | # Filter results to those with score >= PII_THRESHOLD_PROB
81 | # From manual inspection, PII_THRESHOLD_PROB = 0.5 seems reasonable
82 | results = []
83 | for result in analyzer_results:
84 | result = result.to_dict()
85 | if result['score'] >= PII_THRESHOLD_PROB:
86 | doc_len_chars = len(doc_text)
87 | context = ""
88 | if result['start'] - CONTEXT_WINDOW < 0 and result['end'] + CONTEXT_WINDOW + 1 > doc_len_chars:
89 |                     context = doc_text[0:doc_len_chars]
90 | elif result['start'] - CONTEXT_WINDOW < 0:
91 | context = doc_text[0:result['end'] + CONTEXT_WINDOW + 1]
92 | elif result['end'] + CONTEXT_WINDOW + 1 > doc_len_chars:
93 | context = doc_text[result['start'] - CONTEXT_WINDOW:doc_len_chars]
94 | else:
95 | context = doc_text[result['start'] - CONTEXT_WINDOW:result['end'] + CONTEXT_WINDOW + 1]
96 | results.append({'type': result['entity_type'], 'span': doc_text[result['start']: result['end']], 'context': context, 'start': result['start'], 'end': result['end'], 'score': result['score']})
97 | return results
98 |
99 | # Detect PII
100 | try:
101 | doc_pii = detect_pii(doc_text)
102 | except ValueError as error:
103 | print(error)
104 |
105 | n = math.ceil(len(doc_text) / MAX_LEN)
106 | print(n)
107 | chunks = [doc_text[i:i+MAX_LEN] for i in range(0, len(doc_text), MAX_LEN)]
108 | doc_pii = []
109 | for chunk in chunks:
110 | doc_pii_chunk = detect_pii(chunk)
111 | doc_pii += doc_pii_chunk
112 |
113 | doc_pii_count = len(doc_pii)
114 |
115 | # Aggregate doc log data
116 | doc_log_data = {}
117 | doc_log_data["url"] = doc["url"]
118 | doc_log_data["word_count"] = doc_word_count
119 | doc_log_data["pii_count"] = doc_pii_count
120 | doc_log_data["pii"] = doc_pii
121 |
122 | return doc_log_data
123 |
124 |
125 | def scrub(split, name, dataset):
126 | docs_log_data = []
127 |
128 | # Global counts across the dataset
129 | pii_count = 0
130 | word_count = 0
131 | docs_with_pii = 0
132 | doc_count = 0
133 |
134 | results = None
135 | with Pool(processes=cpu_count() - 1) as p:
136 | results = list(tqdm(p.imap(scrub_doc, dataset), total=len(dataset)))
137 |
138 | for doc_log_data in results:
139 | word_count += doc_log_data["word_count"]
140 | pii_count += doc_log_data["pii_count"]
141 | if doc_log_data["pii_count"] > 0:
142 | docs_with_pii += 1
143 | doc_count += 1
144 |
145 | docs_log_data.append(doc_log_data)
146 |
147 | dataset_log_data = {}
148 | dataset_log_data["split"] = split
149 | dataset_log_data["name"] = name
150 | dataset_log_data["pii_count"] = pii_count
151 | dataset_log_data["word_count"] = word_count
152 | dataset_log_data["docs_with_pii"] = docs_with_pii
153 | dataset_log_data["doc_count"] = doc_count
154 | dataset_log_data["per_doc_log_data"] = docs_log_data
155 |
156 | return dataset_log_data
157 |
158 |
159 | def init_presidio():
160 | analyzer = AnalyzerEngine()
161 | anonymizer = AnonymizerEngine()
162 | return analyzer, anonymizer
163 |
164 |
165 | def save_json_file(data, out_dir, filename):
166 | if not os.path.exists(out_dir):
167 | os.makedirs(out_dir)
168 | filepath = os.path.join(out_dir, filename)
169 | with open(filepath, 'w') as out_file:
170 | json.dump(data, out_file)
171 |
172 |
173 | def main(args):
174 | split = args.split
175 | name = args.name
176 | print(f"{split}.{name}")
177 | dataset = load_dataset(DATA_URL, args.name, use_auth_token=True, split=args.split, streaming=False)
178 | dataset_log_data = scrub(split, name, dataset)
179 |
180 | # Save dataset log data (statistics on PII) to json
181 | save_json_file(dataset_log_data, LOG_DIR, f'{split}.{name}.json')
182 |
183 |
184 | if __name__ == '__main__':
185 | parser = argparse.ArgumentParser()
186 | parser.add_argument("--split")
187 | parser.add_argument("--name")
188 | args = parser.parse_args()
189 | main(args)
190 |
--------------------------------------------------------------------------------
/scrub/scrub.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import json
4 | try:
5 | import lzma as xz
6 | except ImportError:
7 | import pylzma as xz
8 | import scrubadub
9 | from tqdm import tqdm
10 |
11 | # DATA DIRS
12 | IN_DIR = 'data'
13 | OUT_DIR = 'data_scrubbed'
14 | FILTH_DIR = 'filth'
15 |
16 | # remaining
17 | TO_SCRUB = ['train.atticus_contracts.jsonl.xz',
18 | 'validation.atticus_contracts.jsonl.xz',
19 | 'train.bva.jsonl.xz',
20 | 'validation.bva.jsonl.xz',
21 | 'train.congressional_hearings.xz',
22 | 'validation.congressional_hearings.xz',
23 | 'train.courtlistenerdocketentries.xz',
24 | 'validation.courtlistenerdocketentries.xz',
25 | 'train.courtlisteneropinions.xz',
26 | 'validation.courtlisteneropinions.xz',
27 | 'train.edgar.jsonl.xz',
28 | 'validation.edgar.jsonl.xz',
29 | ]
30 |
31 |
32 | def scrub(filepath, scrubber):
33 | """This function scrubs the data file of sensitive information.
34 | Returns
35 | - A list of the documents in the data file, with text scrubbed
36 | - A dict of the file filth data
37 | """
38 | print(f"Scrub {filepath}")
39 |
40 | print(f"Read {filepath}")
41 | lines = []
42 | with xz.open(filepath, mode='rb') as f:
43 | while True:
44 | try:
45 | line = f.readline().decode('utf-8')
46 | if line == "":
47 | break
48 | lines.append(line)
49 | except:
50 | print("corrupted line")
51 | break
52 |
53 |
54 | print(f"Clean {filepath}")
55 | docs_scrubbed = []
56 | docs_filth_data = []
57 |
58 | filth_count = 0
59 | word_count = 0
60 | filth_doc_count = 0
61 |
62 | for line in tqdm(lines):
63 | if line is not None and line != "":
64 | doc = json.loads(line)
65 | doc_text = doc["text"]
66 | doc_word_count = len(doc_text.split())
67 | word_count += doc_word_count
68 |
69 | # Detect filth
70 | filth_objs = list(scrubber.iter_filth(doc_text))
71 | doc_filth_count = len(filth_objs)
72 | filth_count += doc_filth_count
73 | if doc_filth_count > 0:
74 | filth_doc_count += 1
75 |
76 | doc_filth = []
77 | for filth_obj in filth_objs:
78 | filth = {}
79 | if isinstance(filth_obj, scrubadub.filth.base.MergedFilth): # filth_obj is MergedFilth
80 | types = []
81 | texts = []
82 | for filth_obj_i in filth_obj.filths:
83 | types.append(filth_obj_i.detector_name)
84 | texts.append(filth_obj_i.text)
85 | filth["type"] = types
86 | filth["text"] = texts
87 | filth["merged_text"] = filth_obj.text
88 | filth["merged"] = True
89 | else: # filth_obj is Filth
90 | filth["type"] = filth_obj.detector_name
91 | filth["text"] = filth_obj.text
92 | filth["merged"] = False
93 | doc_filth.append(filth)
94 |
95 | doc_filth_data = {}
96 | doc_filth_data["url"] = doc["url"]
97 | doc_filth_data["filth_count"] = doc_filth_count
98 | doc_filth_data["word_count"] = doc_word_count
99 | doc_filth_data["filth"] = doc_filth
100 |
101 | # Clean
102 | cleaned_text = scrubber.clean(doc_text)
103 | doc["text"] = cleaned_text
104 |
105 | docs_scrubbed.append(doc)
106 | docs_filth_data.append(doc_filth_data)
107 |
108 | file_filth_data = {}
109 | file_filth_data["filename"] = filepath.split('/')[-1]
110 | file_filth_data["filth_count"] = filth_count
111 | file_filth_data["word_count"] = word_count
112 | file_filth_data["filth_doc_count"] = filth_doc_count
113 | file_filth_data["doc_count"] = len(lines)
114 | file_filth_data["filth_data"] = docs_filth_data
115 |
116 | return docs_scrubbed, file_filth_data
117 |
118 |
119 | def init_scrubber():
120 | """Initialize scrubber with detectors."""
121 | detector_list = [scrubadub.detectors.en_US.SocialSecurityNumberDetector,
122 | scrubadub.detectors.PhoneDetector,
123 | scrubadub.detectors.EmailDetector,
124 | scrubadub.detectors.CreditCardDetector
125 | ]
126 | scrubber = scrubadub.Scrubber(detector_list=detector_list, locale='en_US')
127 | return scrubber
128 |
129 |
130 | def save_compressed_file(data, out_dir, filename):
131 | if not os.path.exists(out_dir):
132 | os.makedirs(out_dir)
133 | filepath = os.path.join(out_dir, filename)
134 | with xz.open(filepath, 'w') as out_file:
135 | for x in data:
136 | out_file.write((json.dumps(x) + "\n").encode("utf-8"))
137 |
138 |
139 | def save_json_file(data, out_dir, filename):
140 | if not os.path.exists(out_dir):
141 | os.makedirs(out_dir)
142 | filepath = os.path.join(out_dir, filename)
143 | with open(filepath, 'w') as out_file:
144 | json.dump(data, out_file)
145 |
146 |
147 | def main():
148 | filepaths_all = glob.glob(os.path.join(IN_DIR, "*.xz"))
149 | filepaths_to_scrub = [filepath for filepath in filepaths_all if filepath.split('/')[-1] in TO_SCRUB]
150 |
151 | scrubber = init_scrubber()
152 | for filepath in filepaths_to_scrub:
153 | docs_scrubbed, filth_data = scrub(filepath, scrubber)
154 |
155 | filename = filepath.split('/')[-1]
156 | trunc_filename = ".".join(filename.split('.')[0:2])
157 |
158 | # Save cleaned docs as compressed xz file
159 | save_compressed_file(docs_scrubbed, OUT_DIR, filename)
160 |
161 | # Save file filth data to json
162 | save_json_file(filth_data, FILTH_DIR, trunc_filename + ".json")
163 |
164 |
165 | if __name__ == '__main__':
166 | main()
--------------------------------------------------------------------------------
/toxicity/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Processing Supreme Court Opinions with Toxicity APIs
3 |
4 | The scotus_only* files run toxicity scoring over SCOTUS documents only. Note: for this we use the Case Law Access Project (CAP) Supreme Court decisions data because it contains accurate metadata for opinion dates. To run these scripts, you will also need to download the bulk SCOTUS opinions from CAP. Their output, a sentence-by-sentence scoring of all Supreme Court opinions, can be found here:
5 |
6 | https://drive.google.com/drive/folders/1QvLdlBGHZYX6mv5HCh1SL_QZLCs5yN6-?usp=sharing
7 |
8 | This is roughly 11 GB of data.
9 |
10 | The data is formatted so that each row is a sentence, indexed by document, with a score from one of the APIs. The mapping from document and sentence indices back to the source opinions can be found in sent_mapping.pkl (see the loading sketch at the end of this README).
11 |
12 | # Figure 2
13 |
14 | The Figure 2 generation code can be found in fig2
15 |
16 |
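17 | Below is a minimal sketch of how the released scores can be joined back onto the sentence index, assuming one of the per-sentence CSVs (e.g. detoxify_scores.csv) with the DocIndex / SentenceIndex / Score columns referenced in fig2/code/fig2.R:
18 | 
19 | ```python
20 | # Sketch: map per-sentence scores back to case metadata via sent_mapping.pkl.
21 | import pickle
22 | 
23 | import pandas as pd
24 | 
25 | with open("sent_mapping.pkl", "rb") as f:
26 |     sent_mapping = pickle.load(f)  # sent_mapping[doc_idx][sent_idx] -> metadata dict
27 | 
28 | # Some CSVs have one row per toxicity category; take a max over Score per sentence if needed.
29 | scores = pd.read_csv("detoxify_scores.csv")
30 | 
31 | rows = []
32 | for _, r in scores.iterrows():
33 |     meta = sent_mapping[int(r["DocIndex"])][int(r["SentenceIndex"])]
34 |     rows.append({
35 |         "case": meta["name"],
36 |         "decision_date": meta["decision_date"],
37 |         "sentence": meta["sent"],
38 |         "score": r["Score"],
39 |     })
40 | joined = pd.DataFrame(rows)
41 | print(joined.sort_values("score", ascending=False).head())
42 | ```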
--------------------------------------------------------------------------------
/toxicity/create_doc_sent_index.py:
--------------------------------------------------------------------------------
1 | # Builds the document/sentence index (sent_mapping.pkl) used by the toxicity scoring scripts (e.g., unitaryai/detoxify, Perspective)
2 | from collections import defaultdict
3 | from tqdm.contrib.concurrent import process_map # or thread_map
4 | from multiprocess import set_start_method
5 | set_start_method('spawn', force=True)
6 |
7 |
8 | from collections import defaultdict
9 | import csv
10 | import glob
11 | import os
12 | import json
13 | import numpy as np
14 | import math
15 | import random
16 | import argparse
17 | from detoxify import Detoxify
18 | from tqdm import tqdm
19 | from datasets import load_dataset
20 | from multiprocessing import Pool, cpu_count
21 | import lexnlp.nlp.en.segments.sentences as lexnlp
22 | from googleapiclient import discovery
23 | import time
24 |
25 | from sys import getsizeof
26 |
27 | from transformers import pipeline
28 |
29 | base_dir = "./"
30 |
31 | PIPE = None
32 |
33 | DATA_URL = 'pile-of-law/pile-of-law'
34 |
35 |
36 | SPLITS = ['train', 'validation']
37 |
38 | # Threshold probabilities
39 | PROFANITY_THRESHOLD_PROB = 0.8
40 |
41 | MAX_LEN = int(1000000 / 10)
42 |
43 |
44 | def chunks(lst, n):
45 | """Yield successive n-sized chunks from lst."""
46 | for i in range(0, len(lst), n):
47 | yield lst[i:i + n]
48 |
49 |
50 | def save_json_file(data, out_dir, filename):
51 | if not os.path.exists(out_dir):
52 | os.makedirs(out_dir)
53 | filepath = os.path.join(out_dir, filename)
54 | with open(filepath, 'w') as out_file:
55 | json.dump(data, out_file)
56 |
57 | class NumpyEncoder(json.JSONEncoder):
58 | def default(self, obj):
59 | if isinstance(obj, np.ndarray):
60 | return obj.tolist()
61 | return json.JSONEncoder.default(self, obj)
62 |
63 | def get_opinions(stuff):
64 | stuff = json.loads(stuff)
65 | return {
66 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]),
67 | "decision_date" : stuff["decision_date"],
68 | "name" : stuff["name"],
69 | "docket_number" : stuff["docket_number"],
70 | "citations" : [x["cite"] for x in stuff["citations"]]
71 | }
72 |
73 | def main(args):
74 | split = args.split
75 | name = args.name
76 | print(f"{split}.{name}")
77 | results = []
78 | with open("data.jsonl", "r") as f:
79 | opinions = f.readlines()
80 | opinions = [get_opinions(x) for x in opinions]
81 | opinions = [x for x in opinions if len(x["text"]) > 3000]
82 | dates = [x["decision_date"] for x in opinions]
83 | names = [x["name"] for x in opinions]
84 |
85 | from torch.utils.data import Dataset
86 |
87 | class MyDataset(Dataset):
88 |
89 | def __init__(self) -> None:
90 | super().__init__()
91 | self.data = defaultdict(dict)
92 | for doc_idx, opinion in enumerate(opinions):
93 | sentences = lexnlp.get_sentence_list(opinion["text"])
94 | for sentence_idx, sentence in enumerate(sentences):
95 | self.data[doc_idx][sentence_idx] = {
96 | "sent" : sentence,
97 | "name" : opinion["name"],
98 | "docket_number" : opinion["docket_number"],
99 | "decision_date" : opinion["decision_date"],
100 | "citations" : opinion["citations"]
101 | }
102 |
103 | def __len__(self):
104 | return len(self.data)
105 |
106 | def __getitem__(self, i):
107 | return self.data[i]["text"]
108 |
109 |
110 | dataset = MyDataset()
111 | import pickle
112 | with open("sent_mapping.pkl", "wb") as f:
113 | pickle.dump(dataset.data, f)
114 |
115 |
116 |
117 |
118 |
119 |
120 | if __name__ == '__main__':
121 | parser = argparse.ArgumentParser()
122 | parser.add_argument("--split")
123 | parser.add_argument("--name")
124 |
125 | args = parser.parse_args()
126 | main(args)
127 |
--------------------------------------------------------------------------------
/toxicity/fig2/code/fig2.R:
--------------------------------------------------------------------------------
1 | #########################################
2 | #### Pile of Law, Figure 2 ####
3 | # Last updated: June 28 #
4 | #########################################
5 |
6 | #### SETUP ####
7 | # Load software
8 | rm(list=ls())
9 | library(psych)
10 | library(tidyverse)
11 | library(ggplot2)
12 | library(stringr)
13 | library(gridExtra)
14 |
15 | # Set root
16 | root <- "/Users/MarkKrass/Downloads/"
17 |
18 | ### Import Toxicity Scores
19 | # Get data: Profanity Checker
20 | pc <- read.csv(paste0(root, "pc_only_scores.csv"),
21 | colClasses = c("numeric","numeric",
22 | "factor","factor",
23 | "numeric","character","factor")) %>%
24 | mutate(year = as.numeric(str_extract(Date, "^[[:digit:]]{4}"))) %>%
25 | select(-Date)
26 |
27 | # Get data: Toxigen
28 | tx <- read.csv("/scratch/users/mkrass/pile/perspective_scores_toxigen_only.csv",
29 | stringsAsFactors = T)
30 |
31 | # Get data: Detoxify
32 | dt.max <- read.csv("/scratch/users/mkrass/pile/detoxify_scores.csv",
33 | stringsAsFactors = T )%>%
34 | group_by(DocIndex, SentenceIndex) %>%
35 | summarize(Score = max(Score))
36 |
37 | # Get data: Perspective
38 | pp <- read.csv("/scratch/users/mkrass/pile/perspective_scores_pc_only.csv",
39 | stringsAsFactors = T) %>%
40 |   filter(Category %in% c("TOXICITY", "SEVERE_TOXICITY", "IDENTITY_ATTACK", "INSULT", "PROFANITY", "THREAT")) %>% # Perspective attributes requested upstream (tox_classes was undefined)
41 | group_by(DocIndex, SentenceIndex) %>%
42 | summarize(Score = max(Score))
43 |
44 | ### Merge Toxicity Score
45 | pc <- tx %>% select(DocIndex,SentenceIndex,Score) %>% rename(toxigen=Score) %>%
46 | right_join(pc)
47 | pc <- dt.max %>% select(DocIndex,SentenceIndex,Score) %>% rename(detoxify=Score) %>%
48 | right_join(pc)
49 | pc <- pp %>% select(DocIndex,SentenceIndex,Score) %>% rename(perspective=Score) %>%
50 | right_join(pc)
51 |
52 |
53 | ### Load the Supreme Court Database (Spaeth, Epstein et al. 2021)
54 | load(paste0(root, "SCDB_2021_01_caseCentered_Citation.Rdata"))
55 | load(paste0(root, "SCDB_Legacy_07_caseCentered_Citation.Rdata"))
56 |
57 | # Get crosswalk of case names
58 | scdb.xw <- read.csv(paste0(root,"scdb_crosswalk.csv"), stringsAsFactors = T) %>%
59 | rename(Name=name) %>%
60 | mutate(year = as.numeric(substr(decision_date, start=1,stop=4)))
61 |
62 | # Rename
63 | scdb <- SCDB_2021_01_caseCentered_Citation %>% select(caseId, dateDecision, caseName, usCite, issueArea) ; rm(SCDB_2021_01_caseCentered_Citation) # keep caseId so left_join(scdb) below has a key
64 | scdb.legacy <- SCDB_Legacy_07_caseCentered_Citation %>% select(caseId, dateDecision, caseName, usCite, issueArea); rm(SCDB_Legacy_07_caseCentered_Citation)
65 |
66 |
67 | pc <- pc %>%
68 | left_join(scdb.xw[,c("Name","year","usr","scdb")]) %>%
69 | distinct(DocIndex,SentenceIndex,.keep_all=T)
70 |
71 |
72 | pc <- pc %>% rename(caseId = scdb) %>% left_join(scdb)
73 |
74 | pcm.cr.long <- pc %>%
75 | select(year,DocIndex, issueArea,perspective,detoxify,Score,toxigen) %>%
76 | pivot_longer(perspective:toxigen,names_to="model",values_to="score2")
77 |
78 |
79 |
80 | # Save merged
81 | pc %>% saveRDS(file=paste0(root, "merged_maxes.RDS"))
82 |
83 | # (Optional: Load merged )
84 | #pc <- readRDS(paste0(root,"merged_maxes.RDS"))
85 |
86 |
87 | #### Obtain Cohen's K at Multiple Thresholds ####
88 |
89 | roundUp <- function(x){round(x/10)*10}
90 | out <- c()
91 | yrs <- c()
92 | w <- c()
93 | ts <- c()
94 | # Group case years into 10-year bins
95 | pc$yr_bin <- sapply(pc$year, roundUp)
96 |
97 | # At multiple thresholds, obtain Cohen's K
98 | for(t in c(0.5,0.8,0.9,0.95)){
99 | pc <- pc %>% mutate(
100 | detoxify.yn = ifelse(detoxify>t,1,0),
101 | profanitycheck.yn = ifelse(Score>t,1,0),
102 | toxigen.yn = ifelse(toxigen>t,1,0),
103 | perspective.yn = ifelse(perspective>t,1,0),
104 | iss_desc = case_when(issueArea == 2 ~ "Civil Rights",
105 | TRUE ~ "All Others"))
106 | # Exclude early years with few data points
107 | for(y in sort(unique(pc$yr_bin))[5:26]){
108 | # Focus on comparing Perspective and Profanity-Checker,
109 | # since these are especially popular
110 | dat <- pc %>%
111 | filter(yr_bin == y) %>%
112 | select(perspective.yn,profanitycheck.yn) %>%
113 | drop_na() %>%
114 | as.matrix()
115 | k <- cohen.kappa(x=as.matrix(dat))$kappa
116 | out <- c(out,k)
117 | yrs <- c(yrs,y)
118 | w <- c(w, nrow(dat))
119 | ts <- c(ts, t)
120 | }}
121 |
122 | # Collect data
123 | cpdat <- data.frame(k = out, year=yrs, w=w, ts=ts)
124 |
125 | #### Plot ####
126 |
127 | # Left panel: Cohen's K
128 | cohens <- cpdat %>%
129 | # Remove points with too few data points
130 | filter(ts %in% c(0.5),w>5000) %>%
131 | ggplot(aes(x=year,y=k)) +
132 | geom_point(aes(size=w)) +
133 | geom_line() + scale_size(guide="none") +
134 | ylab("Cohen's K") + xlim(1880,2010)
135 |
136 |
137 | # Right panel: issue scores by area over time
138 | plt.issue.area.share <- pcm.cr.long %>%
139 | mutate(iss_desc = case_when(issueArea == 2 ~ "Civil Rights",
140 | TRUE ~ "All Others"),
141 | model_desc = case_when(model == "Score" ~ "profanity_checker",
142 | TRUE ~ model),
143 | score.yn = ifelse(score2 > 0.5,1,0)) %>%
144 | group_by(year,model_desc,iss_desc) %>%
145 | summarize(score=mean(score.yn),n=n()) %>%
146 | ungroup() %>%
147 | ggplot(aes(x=year,y=score, color=model_desc)) +
148 | geom_smooth(span=0.3)+ylab("Share P(Toxic) > 0.5")+
149 | facet_wrap(~iss_desc) + xlim(1875,2022) + theme(legend.position = "right", legend.text = element_text(size = 6),
150 | legend.title = element_text(size = 8)) +
151 | guides(color = guide_legend(override.aes = list(size = 0.5)))
152 |
153 |
154 | ggsave(plt.issue.area.share, filename = paste0(root,"share_toxic_issue.pdf"),
155 | device="pdf",width=7,height=3)
156 |
157 |
158 | ## Assemble plots
159 | assembled <- grid.arrange(cohens, plt.issue.area.share, nrow=1, widths=c(0.3,0.7))
160 | ggsave(assembled, filename = paste0(root,"combined_cohens_toxic.pdf"),
161 | device="pdf",width=8,height=3)
162 |
163 |
164 |
--------------------------------------------------------------------------------
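Note: fig2.R binarizes the merged per-sentence scores at each threshold and computes agreement with psych::cohen.kappa. A rough Python equivalent of the same statistic (an illustration only, using scikit-learn on a toy DataFrame; the column names mirror the merged pc table in fig2.R, where "Score" is the profanity-check score):

# Cohen's kappa between two binarized toxicity scores at a single threshold.
import pandas as pd
from sklearn.metrics import cohen_kappa_score

pc = pd.DataFrame({"perspective": [0.1, 0.9, 0.7, 0.2],
                   "Score":       [0.2, 0.8, 0.3, 0.1]})
t = 0.5
kappa = cohen_kappa_score((pc["perspective"] > t).astype(int),
                          (pc["Score"] > t).astype(int))
print(f"Cohen's kappa at threshold {t}: {kappa:.2f}")
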
/toxicity/fig2/data/SCDB.Rdata:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Breakend/PileOfLaw/e6f7b11ba52e0fbbd392ad00bf3cc1632ae710b6/toxicity/fig2/data/SCDB.Rdata
--------------------------------------------------------------------------------
/toxicity/fig2/data/SCDBLegacy.Rdata:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Breakend/PileOfLaw/e6f7b11ba52e0fbbd392ad00bf3cc1632ae710b6/toxicity/fig2/data/SCDBLegacy.Rdata
--------------------------------------------------------------------------------
/toxicity/fig2/data/SCDB_2018_02_caseCentered_Citation.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Breakend/PileOfLaw/e6f7b11ba52e0fbbd392ad00bf3cc1632ae710b6/toxicity/fig2/data/SCDB_2018_02_caseCentered_Citation.csv
--------------------------------------------------------------------------------
/toxicity/fig2/data/SCDB_2019_01_caseCentered_Citation.Rdata:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Breakend/PileOfLaw/e6f7b11ba52e0fbbd392ad00bf3cc1632ae710b6/toxicity/fig2/data/SCDB_2019_01_caseCentered_Citation.Rdata
--------------------------------------------------------------------------------
/toxicity/scotus_only_pc_only.py:
--------------------------------------------------------------------------------
1 | # Scores SCOTUS opinion sentences for profanity with profanity-check (predict_prob)
2 | from collections import defaultdict
3 | from tqdm.contrib.concurrent import process_map # or thread_map
4 |
5 |
6 | import csv
7 | import glob
8 | import os
9 | import json
10 | import numpy as np
11 | import math
12 | import random
13 | import argparse
14 | from profanity_check import predict_prob
15 | from detoxify import Detoxify
16 | from tqdm import tqdm
17 | from datasets import load_dataset
18 | from multiprocessing import Pool, cpu_count
19 | import lexnlp.nlp.en.segments.sentences as lexnlp
20 | from googleapiclient import discovery
21 | import time
22 | from apiclient.errors import HttpError
23 |
24 | from sklearn.feature_extraction.text import TfidfVectorizer
25 | from sklearn.linear_model import RidgeClassifier
26 | from sys import getsizeof
27 | from joblib import load
28 |
29 | from transformers import pipeline
30 |
31 | base_dir = "./"
32 |
33 |
34 | PIPE = None
35 |
36 |
37 | import os
38 |
39 | # DATA DIRS
40 | DATA_URL = 'pile-of-law/pile-of-law'
41 |
42 |
43 | SPLITS = ['train', 'validation']
44 |
45 | # Threshold probabilities
46 | PROFANITY_THRESHOLD_PROB = 0.8
47 |
48 | MAX_LEN = int(1000000 / 10)
49 |
50 |
51 |
52 | def profanity_check(sentences):
53 | res = predict_prob(sentences)
54 | return { i : {"PROFANITY" : res[i]} for i in range(len(sentences))}
55 |
56 | def exponential_backoff(req):  # unused in this script; requires the Perspective `client` defined in scotus_only_perspective_pc.py
57 | for n in range(10):
58 | try:
59 | return client.comments().analyze(body=req).execute()
60 | except HttpError as error:
61 | if error.resp.reason.strip() not in ["Too Many Requests", "Resource has been exhausted (e.g. check quota)."]:
62 | print(error.resp.reason)
63 | import pdb; pdb.set_trace()
64 | raise
65 | if n < 9:
66 | time.sleep((random.random() * 5 * n) + random.random())
67 | print(f"BACKING OFF {n} for {error.resp.reason}")
68 | else:
69 | raise
70 | else:
71 | break
72 |
73 | def _conver_label_to_score(l):
74 | if l["label"] == "LABEL_0":
75 | return 1 - l["score"]
76 | else:
77 | return l["score"]
78 |
79 | def chunks(lst, n):
80 | """Yield successive n-sized chunks from lst."""
81 | for i in range(0, len(lst), n):
82 | yield lst[i:i + n]
83 |
84 |
85 |
86 | def scrub_doc(doc):
87 | doc_text = doc["text"]
88 |
89 | # Detect profanity
90 | # profanity-check
91 | sentences = lexnlp.get_sentence_list(doc_text)
92 |
93 | _return = {}
94 | if len(sentences) == 0:
95 | _return = {}
96 | _return["num_sentences"] = len(sentences)
97 | _return["examples"] = json.dumps({})
98 |         _return["confusion"] = json.dumps(np.zeros((4,4)), cls=NumpyEncoder)
99 |         return _return  # nothing to score; avoid calling predict_prob on an empty list
100 | results_pc = profanity_check(sentences)
101 |
102 | _return["sentences"] = json.dumps(sentences)
103 | _return["profanity_check_scores"] = json.dumps(results_pc)
104 | _return["num_sentences"] = len(sentences)
105 | return _return
106 |
107 |
108 | def save_json_file(data, out_dir, filename):
109 | if not os.path.exists(out_dir):
110 | os.makedirs(out_dir)
111 | filepath = os.path.join(out_dir, filename)
112 | with open(filepath, 'w') as out_file:
113 | json.dump(data, out_file)
114 |
115 | class NumpyEncoder(json.JSONEncoder):
116 | def default(self, obj):
117 | if isinstance(obj, np.ndarray):
118 | return obj.tolist()
119 | return json.JSONEncoder.default(self, obj)
120 |
121 | def get_opinions(stuff):
122 | stuff = json.loads(stuff)
123 | return {
124 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]),
125 | "decision_date" : stuff["decision_date"],
126 | "name" : stuff["name"]
127 | }
128 |
129 | def main(args):
130 | split = args.split
131 | name = args.name
132 | print(f"{split}.{name}")
133 | results = []
134 | with open("data.jsonl", "r") as f:
135 | opinions = f.readlines()
136 | opinions = [get_opinions(x) for x in opinions]
137 | opinions = [x for x in opinions if len(x["text"]) > 3000]
138 | dates = [x["decision_date"] for x in opinions]
139 | names = [x["name"] for x in opinions]
140 |
141 | print(len(opinions))
142 | from datasets import Dataset
143 | df = Dataset.from_dict({k: [dic[k] for dic in opinions] for k in opinions[0].keys()})
144 | import sys
145 | import traceback
146 | try:
147 | results2 = df.map(scrub_doc, num_proc=32)
148 | except Exception as e:
149 | import pdb; pdb.set_trace()
150 | del opinions
151 |
152 | total_sentences = 0.0
153 |
154 |
155 | with open(os.path.join(base_dir, "pc_only_scores.csv"), "w") as scores_f:
156 | spamwriter = csv.writer(scores_f)
157 | spamwriter.writerow(["DocIndex", "SentenceIndex", "API", "Category", "Score", "Date", "Name"])
158 | for results in [results2]:
159 | for i, (result, name, date) in enumerate(zip(results,names, dates)):
160 | for t in ["profanity_check", "perspective"]:
161 | if f"{t}_scores" not in result:
162 | continue
163 | scores = json.loads(result[f"{t}_scores"])
164 | # doc number
165 | for k, v in scores.items():
166 | for k2, v2 in v.items():
167 | datarow = (i, k, t, k2, v2, date, name)
168 | spamwriter.writerow(datarow)
169 |
170 |
171 |
172 |
173 | if __name__ == '__main__':
174 | parser = argparse.ArgumentParser()
175 | parser.add_argument("--split")
176 | parser.add_argument("--name")
177 |
178 | args = parser.parse_args()
179 | main(args)
180 |
--------------------------------------------------------------------------------
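Note: the CSVs written here and by the sibling scripts share the long format (DocIndex, SentenceIndex, API, Category, Score, Date, Name) that fig2.R later collapses to one max score per sentence. A rough pandas equivalent of that aggregation (an illustration, not part of the pipeline; it assumes pc_only_scores.csv was produced by the script above):

# Collapse the long-format score file to the per-sentence max, as fig2.R does.
import pandas as pd

scores = pd.read_csv("pc_only_scores.csv")
max_scores = (
    scores.groupby(["DocIndex", "SentenceIndex"], as_index=False)["Score"].max()
)
print(max_scores.head())
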
/toxicity/scotus_only_perspective_pc.py:
--------------------------------------------------------------------------------
1 | # Scores SCOTUS opinion sentences with the Perspective API (commentanalyzer)
2 | from collections import defaultdict
3 | from tqdm.contrib.concurrent import process_map # or thread_map
4 |
5 |
6 | import csv
7 | import glob
8 | import os
9 | import json
10 | import numpy as np
11 | import math
12 | import random
13 | import argparse
14 | from profanity_check import predict_prob
15 | from detoxify import Detoxify
16 | from tqdm import tqdm
17 | from datasets import load_dataset
18 | from multiprocessing import Pool, cpu_count
19 | import lexnlp.nlp.en.segments.sentences as lexnlp
20 | from googleapiclient import discovery
21 | import time
22 | from apiclient.errors import HttpError
23 |
24 | from sklearn.feature_extraction.text import TfidfVectorizer
25 | from sklearn.linear_model import RidgeClassifier
26 | from sys import getsizeof
27 | from joblib import load
28 |
29 | from transformers import pipeline
30 |
31 | base_dir = "./"
32 |
33 |
34 | PIPE = None
35 |
36 |
37 | def toxigen_roberta():
38 | # This will load the pipeline on demand on the current PROCESS/THREAD.
39 | # And load it only once.
40 | global PIPE
41 | if PIPE is None:
42 | PIPE = pipeline("text-classification", model="tomh/toxigen_roberta", device=0)
43 | return PIPE
44 |
45 | import os
46 |
47 | API_KEY = ""
48 | # DATA DIRS
49 | DATA_URL = 'pile-of-law/pile-of-law'
50 |
51 |
52 | client = discovery.build(
53 | "commentanalyzer",
54 | "v1alpha1",
55 | developerKey=API_KEY,
56 | discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1",
57 | static_discovery=False,
58 | )
59 |
60 | SPLITS = ['train', 'validation']
61 |
62 | # Threshold probabilities
63 | PROFANITY_THRESHOLD_PROB = 0.8
64 |
65 | MAX_LEN = int(1000000 / 10)
66 |
67 | def exponential_backoff(req):
68 | for n in range(10):
69 | try:
70 | return client.comments().analyze(body=req).execute()
71 | except HttpError as error:
72 | if error.resp.reason.strip() not in ["Too Many Requests", "Resource has been exhausted (e.g. check quota)."]:
73 | print(error.resp.reason)
74 | raise
75 | if n < 9:
76 | time.sleep((random.random() * 5 * n) + random.random())
77 | print(f"BACKING OFF {n} for {error.resp.reason}")
78 | else:
79 | raise
80 | else:
81 | break
82 |
83 | def _conver_label_to_score(l):
84 | if l["label"] == "LABEL_0":
85 | return 1 - l["score"]
86 | else:
87 | return l["score"]
88 |
89 | def chunks(lst, n):
90 | """Yield successive n-sized chunks from lst."""
91 | for i in range(0, len(lst), n):
92 | yield lst[i:i + n]
93 |
94 | def perspective_check_lexnlp(sentences):
95 |
96 | current_sentence_idx = 0
97 | results = defaultdict(dict)
98 | for sentence in sentences:
99 | # for idx, sentence in enumerate(sentences):
100 | # TODO: batching?
101 |
102 | analyze_request = {
103 | 'comment': { 'text': sentence.strip()},
104 | 'spanAnnotations' : False,
105 | "languages" : ["en"],
106 | 'requestedAttributes':{"TOXICITY": {}, "SEVERE_TOXICITY": {}, "IDENTITY_ATTACK" : {}, "INSULT" : {}, "PROFANITY" : {}, "THREAT" : {}}
107 | }
108 | try:
109 | response = exponential_backoff(analyze_request)
110 | except Exception as e:
111 | results[current_sentence_idx]["ERROR"] = 1.0
112 | print(f"ERROR ON SENTENCE {sentence} {e}")
113 |             current_sentence_idx += 1; continue  # `response` is not set when the request fails; record the error and move on
114 | for k, v in response["attributeScores"].items():
115 | results[current_sentence_idx][k] = v["summaryScore"]['value']
116 | current_sentence_idx += 1
117 |
118 | return results
119 |
120 |
121 | def scrub_doc(doc):
122 | doc_text = doc["text"]
123 |
124 | sentences = lexnlp.get_sentence_list(doc_text)
125 | results_p = perspective_check_lexnlp(sentences)
126 | _return = {}
127 | if len(sentences) == 0:
128 | _return = {}
129 | _return["num_sentences"] = len(sentences)
130 | _return["examples"] = json.dumps({})
131 | _return["confusion"] = json.dumps(np.zeros((4,4)), cls=NumpyEncoder)
132 |
133 | _return["sentences"] = json.dumps(sentences)
134 | _return["perspective_scores"] = json.dumps(results_p)
135 | _return["num_sentences"] = len(sentences)
136 | return _return
137 |
138 |
139 | def save_json_file(data, out_dir, filename):
140 | if not os.path.exists(out_dir):
141 | os.makedirs(out_dir)
142 | filepath = os.path.join(out_dir, filename)
143 | with open(filepath, 'w') as out_file:
144 | json.dump(data, out_file)
145 |
146 | class NumpyEncoder(json.JSONEncoder):
147 | def default(self, obj):
148 | if isinstance(obj, np.ndarray):
149 | return obj.tolist()
150 | return json.JSONEncoder.default(self, obj)
151 |
152 | def get_opinions(stuff):
153 | stuff = json.loads(stuff)
154 | return {
155 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]),
156 | "decision_date" : stuff["decision_date"],
157 | "name" : stuff["name"]
158 | }
159 |
160 | def main(args):
161 | split = args.split
162 | name = args.name
163 | print(f"{split}.{name}")
164 | results = []
165 | with open("data.jsonl", "r") as f:
166 | opinions = f.readlines()
167 | opinions = [get_opinions(x) for x in opinions]
168 | opinions = [x for x in opinions if len(x["text"]) > 3000]
169 | dates = [x["decision_date"] for x in opinions]
170 | names = [x["name"] for x in opinions]
171 |
172 | print(len(opinions))
173 | from datasets import Dataset
174 | df = Dataset.from_dict({k: [dic[k] for dic in opinions] for k in opinions[0].keys()})
175 | results2 = df.map(scrub_doc, num_proc=32)
176 | del opinions
177 |
178 | with open(os.path.join(base_dir, "perspective_scores_pc_only.csv"), "w") as scores_f:
179 | spamwriter = csv.writer(scores_f)
180 | spamwriter.writerow(["DocIndex", "SentenceIndex", "API", "Category", "Score", "Date", "Name"])
181 | for results in [results2]:
182 | for i, (result, name, date) in enumerate(zip(results,names, dates)):
183 | for t in ["perspective"]:
184 | if f"{t}_scores" not in result:
185 | continue
186 | scores = json.loads(result[f"{t}_scores"])
187 | # doc number
188 | for k, v in scores.items():
189 | for k2, v2 in v.items():
190 | datarow = (i, k, t, k2, v2, date, name)
191 | spamwriter.writerow(datarow)
192 |
193 |
194 |
195 | if __name__ == '__main__':
196 | parser = argparse.ArgumentParser()
197 | parser.add_argument("--split")
198 | parser.add_argument("--name")
199 |
200 | args = parser.parse_args()
201 | main(args)
202 |
--------------------------------------------------------------------------------
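Note: the script above hard-codes an empty API_KEY. A minimal sketch of supplying the key without editing the file (PERSPECTIVE_API_KEY is a hypothetical environment-variable name, not one the repo defines; the discovery.build call mirrors the script):

# Build the Perspective client with a key read from the environment.
import os
from googleapiclient import discovery

API_KEY = os.environ.get("PERSPECTIVE_API_KEY", "")
client = discovery.build(
    "commentanalyzer",
    "v1alpha1",
    developerKey=API_KEY,
    discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1",
    static_discovery=False,
)
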
/toxicity/scotus_only_toxigen.py:
--------------------------------------------------------------------------------
1 | # Detects toxic language with the ToxiGen RoBERTa classifier (tomh/toxigen_roberta)
2 | from collections import defaultdict
3 | from tqdm.contrib.concurrent import process_map # or thread_map
4 | from multiprocess import set_start_method
5 | set_start_method('spawn', force=True)
6 |
7 |
8 | import csv
9 | import glob
10 | import os
11 | import json
12 | import numpy as np
13 | import math
14 | import random
15 | import argparse
16 | from profanity_check import predict_prob
17 | from tqdm import tqdm
18 | from datasets import load_dataset
19 | from multiprocessing import Pool, cpu_count
20 | import lexnlp.nlp.en.segments.sentences as lexnlp
21 | from googleapiclient import discovery
22 | import time
23 | from apiclient.errors import HttpError
24 |
25 | from sklearn.feature_extraction.text import TfidfVectorizer
26 | from sklearn.linear_model import RidgeClassifier
27 | from sys import getsizeof
28 |
29 | from transformers import pipeline
30 |
31 | base_dir = "./"
32 |
33 |
34 | PIPE = None
35 |
36 | DATA_URL = 'pile-of-law/pile-of-law'
37 |
38 |
39 | SPLITS = ['train', 'validation']
40 |
41 | # Threshold probabilities
42 | PROFANITY_THRESHOLD_PROB = 0.8
43 |
44 | MAX_LEN = int(1000000 / 10)
45 |
46 |
47 | def chunks(lst, n):
48 | """Yield successive n-sized chunks from lst."""
49 | for i in range(0, len(lst), n):
50 | yield lst[i:i + n]
51 |
52 |
53 | def save_json_file(data, out_dir, filename):
54 | if not os.path.exists(out_dir):
55 | os.makedirs(out_dir)
56 | filepath = os.path.join(out_dir, filename)
57 | with open(filepath, 'w') as out_file:
58 | json.dump(data, out_file)
59 |
60 | class NumpyEncoder(json.JSONEncoder):
61 | def default(self, obj):
62 | if isinstance(obj, np.ndarray):
63 | return obj.tolist()
64 | return json.JSONEncoder.default(self, obj)
65 |
66 | def get_opinions(stuff):
67 | stuff = json.loads(stuff)
68 | return {
69 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]),
70 | "decision_date" : stuff["decision_date"],
71 | "name" : stuff["name"]
72 | }
73 |
74 | def toxigen_roberta():
75 | # This will load the pipeline on demand on the current PROCESS/THREAD.
76 | # And load it only once.
77 | global PIPE
78 | if PIPE is None:
79 | PIPE = pipeline("text-classification", model="tomh/toxigen_roberta", device=0, return_all_scores=True)
80 | return PIPE
81 |
82 | import os
83 |
84 | def main(args):
85 | set_start_method('spawn', force=True)
86 | split = args.split
87 | name = args.name
88 | print(f"{split}.{name}")
89 | results = []
90 | with open("data.jsonl", "r") as f:
91 | opinions = f.readlines()
92 | opinions = [get_opinions(x) for x in opinions]
93 | opinions = [x for x in opinions if len(x["text"]) > 3000]
94 | dates = [x["decision_date"] for x in opinions]
95 | names = [x["name"] for x in opinions]
96 |
97 | print(len(opinions))
98 | from torch.utils.data import Dataset
99 |
100 | class MyDataset(Dataset):
101 |
102 | def __init__(self) -> None:
103 | super().__init__()
104 | self.data = []
105 | for doc_idx, opinion in enumerate(opinions):
106 | sentences = lexnlp.get_sentence_list(opinion["text"])
107 | for sentence_idx, sentence in enumerate(sentences):
108 | self.data.append(
109 | {
110 | "sentence_idx" : sentence_idx,
111 | "doc_idx" : doc_idx,
112 | "text" : sentence
113 | })
114 | def __len__(self):
115 | return len(self.data)
116 |
117 | def __getitem__(self, i):
118 | return self.data[i]["text"]
119 |
120 |
121 | dataset = MyDataset()
122 |
123 |
124 | results1 = []
125 | cur_doc = defaultdict(dict)
126 | cur_doc_idx = 0
127 | try:
128 | for i, out in enumerate(tqdm(toxigen_roberta()(dataset, batch_size=32, truncation=True, max_length=512), total=len(dataset))):
129 | if cur_doc_idx != dataset.data[i]["doc_idx"]:
130 | cur_doc_idx = dataset.data[i]["doc_idx"]
131 | results1.append({"toxigen_scores" : json.dumps(cur_doc)})
132 | cur_doc = defaultdict(dict)
133 | cur_doc[dataset.data[i]["sentence_idx"]] = { x["label"] : x["score"] for x in out }
134 |
135 | if len(cur_doc) > 0:
136 | results1.append({"toxigen_scores" : json.dumps(cur_doc)})
137 | cur_doc = defaultdict(dict)
138 | except:
139 | import pdb; pdb.set_trace()
140 |
141 |
142 | with open(os.path.join(base_dir, "toxigen_scores.csv"), "w") as scores_f:
143 | spamwriter = csv.writer(scores_f)
144 | spamwriter.writerow(["DocIndex", "SentenceIndex", "API", "Category", "Score", "Date", "Name"])
145 | for results in [results1]:
146 | for i, (result, name, date) in enumerate(zip(results,names, dates)):
147 | for t in ["toxigen"]:
148 | if f"{t}_scores" not in result:
149 | continue
150 | scores = json.loads(result[f"{t}_scores"])
151 | # doc number
152 | for k, v in scores.items():
153 | for k2, v2 in v.items():
154 | datarow = (i, k, t, k2, v2, date, name)
155 | spamwriter.writerow(datarow)
156 |
157 |
158 | if __name__ == '__main__':
159 | parser = argparse.ArgumentParser()
160 | parser.add_argument("--split")
161 | parser.add_argument("--name")
162 |
163 | args = parser.parse_args()
164 | main(args)
165 |
--------------------------------------------------------------------------------
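Note: with return_all_scores=True the pipeline above yields, for each sentence, a list of {"label", "score"} dicts (LABEL_0 / LABEL_1 for tomh/toxigen_roberta); the sibling scripts' _conver_label_to_score treats LABEL_0 as the non-toxic class. A small standalone sketch of that output shape (an illustration; assumes the model can be downloaded, and device=-1 keeps it on CPU):

# Inspect per-sentence label/score pairs from the ToxiGen RoBERTa pipeline.
from transformers import pipeline

pipe = pipeline("text-classification", model="tomh/toxigen_roberta",
                device=-1, return_all_scores=True)
outputs = pipe(["The motion is denied.", "The appellant's brief is late."],
               truncation=True, max_length=512)
for out in outputs:
    scores = {x["label"]: x["score"] for x in out}  # e.g. {"LABEL_0": ..., "LABEL_1": ...}
    print(scores)
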
/toxicity/scotus_only_unitary.py:
--------------------------------------------------------------------------------
1 | # Detects toxic language with unitary/unbiased-toxic-roberta (Detoxify's unbiased model)
2 | from collections import defaultdict
3 | from tqdm.contrib.concurrent import process_map # or thread_map
4 | from multiprocess import set_start_method
5 | set_start_method('spawn', force=True)
6 |
7 |
8 | import csv
9 | import glob
10 | import os
11 | import json
12 | import numpy as np
13 | import math
14 | import random
15 | import argparse
16 | from profanity_check import predict_prob
17 | from tqdm import tqdm
18 | from datasets import load_dataset
19 | from multiprocessing import Pool, cpu_count
20 | import lexnlp.nlp.en.segments.sentences as lexnlp
21 | import time
22 |
23 | from sklearn.feature_extraction.text import TfidfVectorizer
24 | from sklearn.linear_model import RidgeClassifier
25 | from sys import getsizeof
26 |
27 | from transformers import pipeline
28 |
29 | base_dir = "./"
30 |
31 |
32 | PIPE = None
33 |
34 | DATA_URL = 'pile-of-law/pile-of-law'
35 |
36 |
37 | SPLITS = ['train', 'validation']
38 |
39 | # Threshold probabilities
40 | PROFANITY_THRESHOLD_PROB = 0.8
41 |
42 | MAX_LEN = int(1000000 / 10)
43 |
44 |
45 | def chunks(lst, n):
46 | """Yield successive n-sized chunks from lst."""
47 | for i in range(0, len(lst), n):
48 | yield lst[i:i + n]
49 |
50 |
51 | def save_json_file(data, out_dir, filename):
52 | if not os.path.exists(out_dir):
53 | os.makedirs(out_dir)
54 | filepath = os.path.join(out_dir, filename)
55 | with open(filepath, 'w') as out_file:
56 | json.dump(data, out_file)
57 |
58 | class NumpyEncoder(json.JSONEncoder):
59 | def default(self, obj):
60 | if isinstance(obj, np.ndarray):
61 | return obj.tolist()
62 | return json.JSONEncoder.default(self, obj)
63 |
64 | def get_opinions(stuff):
65 | stuff = json.loads(stuff)
66 | return {
67 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]),
68 | "decision_date" : stuff["decision_date"],
69 | "name" : stuff["name"]
70 | }
71 |
72 | def unitary_roberta():
73 | # This will load the pipeline on demand on the current PROCESS/THREAD.
74 | # And load it only once.
75 | global PIPE
76 | if PIPE is None:
77 | PIPE = pipeline("text-classification", model="unitary/unbiased-toxic-roberta", device=0, return_all_scores=True)
78 | return PIPE
79 |
80 | import os
81 |
82 | def main(args):
83 | set_start_method('spawn', force=True)
84 | split = args.split
85 | name = args.name
86 | print(f"{split}.{name}")
87 | results = []
88 | with open("data.jsonl", "r") as f:
89 | opinions = f.readlines()
90 | opinions = [get_opinions(x) for x in opinions]
91 | opinions = [x for x in opinions if len(x["text"]) > 3000]
92 | dates = [x["decision_date"] for x in opinions]
93 | names = [x["name"] for x in opinions]
94 |
95 | print(len(opinions))
96 | from torch.utils.data import Dataset
97 |
98 | class MyDataset(Dataset):
99 |
100 | def __init__(self) -> None:
101 | super().__init__()
102 | self.data = []
103 | for doc_idx, opinion in enumerate(opinions):
104 | sentences = lexnlp.get_sentence_list(opinion["text"])
105 | for sentence_idx, sentence in enumerate(sentences):
106 | self.data.append(
107 | {
108 | "sentence_idx" : sentence_idx,
109 | "doc_idx" : doc_idx,
110 | "text" : sentence
111 | })
112 | def __len__(self):
113 | return len(self.data)
114 |
115 | def __getitem__(self, i):
116 | return self.data[i]["text"]
117 |
118 |
119 | dataset = MyDataset()
120 |
121 |
122 | results1 = []
123 | cur_doc = defaultdict(dict)
124 | cur_doc_idx = 0
125 | try:
126 | for i, out in enumerate(tqdm(unitary_roberta()(dataset, batch_size=32, truncation=True, max_length=512), total=len(dataset))):
127 | if cur_doc_idx != dataset.data[i]["doc_idx"]:
128 | cur_doc_idx = dataset.data[i]["doc_idx"]
129 | results1.append({"detoxify_scores" : json.dumps(cur_doc)})
130 | cur_doc = defaultdict(dict)
131 | cur_doc[dataset.data[i]["sentence_idx"]] = { x["label"] : x["score"] for x in out }
132 |
133 | if len(cur_doc) > 0:
134 | results1.append({"detoxify_scores" : json.dumps(cur_doc)})
135 | cur_doc = defaultdict(dict)
136 | except:
137 | import pdb; pdb.set_trace()
138 |
139 |
140 | with open(os.path.join(base_dir, "detoxify_scores.csv"), "w") as scores_f:
141 | spamwriter = csv.writer(scores_f)
142 | spamwriter.writerow(["DocIndex", "SentenceIndex", "API", "Category", "Score", "Date", "Name"])
143 | for results in [results1]:
144 | for i, (result, name, date) in enumerate(zip(results,names, dates)):
145 | for t in ["detoxify"]:
146 | if f"{t}_scores" not in result:
147 | continue
148 | scores = json.loads(result[f"{t}_scores"])
149 | # doc number
150 | for k, v in scores.items():
151 | for k2, v2 in v.items():
152 | datarow = (i, k, t, k2, v2, date, name)
153 | spamwriter.writerow(datarow)
154 |
155 |
156 | if __name__ == '__main__':
157 | parser = argparse.ArgumentParser()
158 | parser.add_argument("--split")
159 | parser.add_argument("--name")
160 |
161 | args = parser.parse_args()
162 | main(args)
163 |
--------------------------------------------------------------------------------
/toxicity/toxigen_context_exp.py:
--------------------------------------------------------------------------------
1 | # Re-scores the highest-scoring ToxiGen sentences with surrounding context (tomh/toxigen_roberta)
2 | from collections import defaultdict
3 | from tqdm.contrib.concurrent import process_map # or thread_map
4 | from multiprocess import set_start_method
5 | set_start_method('spawn', force=True)
6 | import pandas as pd
7 |
8 | import csv
9 | import glob
10 | import os
11 | import json
12 | import numpy as np
13 | import math
14 | import random
15 | import pickle
16 | import argparse
17 | from profanity_check import predict_prob
18 | from detoxify import Detoxify
19 | from tqdm import tqdm
20 | from datasets import load_dataset
21 | from multiprocessing import Pool, cpu_count
22 | import lexnlp.nlp.en.segments.sentences as lexnlp
23 | from googleapiclient import discovery
24 | import time
25 | from apiclient.errors import HttpError
26 |
27 | from sklearn.feature_extraction.text import TfidfVectorizer
28 | from sklearn.linear_model import RidgeClassifier
29 | from sys import getsizeof
30 | from joblib import load
31 |
32 | from transformers import pipeline
33 |
34 | base_dir = "./"
35 |
36 | PIPE = None
37 |
38 |
39 | def toxigen_roberta():
40 | # This will load the pipeline on demand on the current PROCESS/THREAD.
41 | # And load it only once.
42 | global PIPE
43 | if PIPE is None:
44 | PIPE = pipeline("text-classification", model="tomh/toxigen_roberta", device=0)
45 | return PIPE
46 |
47 | import os
48 |
49 | # DATA DIRS
50 | DATA_URL = 'pile-of-law/pile-of-law'
51 |
52 |
53 | SPLITS = ['train', 'validation']
54 |
55 | # Threshold probabilities
56 | PROFANITY_THRESHOLD_PROB = 0.8
57 |
58 | MAX_LEN = int(1000000 / 10)
59 |
60 |
61 |
62 | def _conver_label_to_score(l):
63 | if l["label"] == "LABEL_0":
64 | return 1 - l["score"]
65 | else:
66 | return l["score"]
67 |
68 | def chunks(lst, n):
69 | """Yield successive n-sized chunks from lst."""
70 | for i in range(0, len(lst), n):
71 | yield lst[i:i + n]
72 |
73 | def toxigen(sentences):
74 | results = {}
75 | predictions = toxigen_roberta()(sentences)
76 | for i in range(len(sentences)):
77 | results[i] = { "TOXICITY" : _conver_label_to_score(predictions[i]) }
78 | return results
79 |
80 |
81 | def save_json_file(data, out_dir, filename):
82 | if not os.path.exists(out_dir):
83 | os.makedirs(out_dir)
84 | filepath = os.path.join(out_dir, filename)
85 | with open(filepath, 'w') as out_file:
86 | json.dump(data, out_file)
87 |
88 | class NumpyEncoder(json.JSONEncoder):
89 | def default(self, obj):
90 | if isinstance(obj, np.ndarray):
91 | return obj.tolist()
92 | return json.JSONEncoder.default(self, obj)
93 |
94 | def get_opinions(stuff):
95 | stuff = json.loads(stuff)
96 | return {
97 | "text" : "\n".join(x["text"] for x in stuff["casebody"]["data"]["opinions"]),
98 | "decision_date" : stuff["decision_date"],
99 | "name" : stuff["name"]
100 | }
101 |
102 | def main(args):
103 | toxigen_scores = pd.read_csv("./perspective_scores_toxigen_only.csv")
104 | toxigen_scores = toxigen_scores.nlargest(5000,['Score'])
105 | toxigen_scores = toxigen_scores[toxigen_scores["Score"] > .5]
106 | print(len(toxigen_scores))
107 | with open("./sent_mapping.pkl", "rb") as f:
108 | sentence_dict = pickle.load(f)
109 | from torch.utils.data import Dataset
110 |
111 | class MyDataset(Dataset):
112 |
113 | def __init__(self) -> None:
114 | super().__init__()
115 | self.data = []
116 | for i, row in sorted(toxigen_scores.iterrows(), key=lambda x: x[1]["Score"], reverse=True):
117 | sents = []
118 | low = max(0, row["SentenceIndex"]-2)
119 | doc_length = max(sentence_dict[row["DocIndex"]].keys())+1
120 |                 high = min(doc_length, row["SentenceIndex"]+2)  # range(low, high) covers sentences idx-2 .. idx+1, clipped to the document
121 |                 for j in range(low, high):  # use j so the outer row index i is not shadowed
122 |                     sents.append(sentence_dict[row["DocIndex"]][j]["sent"])
123 | sentence = " ".join(sents)
124 | self.data.append(
125 | {
126 | "sentence_idx" : row["SentenceIndex"],
127 | "doc_idx" : row["DocIndex"],
128 | "prev_score" : row["Score"],
129 | "text" : sentence,
130 | "prev_sentence" : sentence_dict[row["DocIndex"]][row["SentenceIndex"]]["sent"]
131 | })
132 | def __len__(self):
133 | return len(self.data)
134 |
135 | def __getitem__(self, i):
136 | return self.data[i]["text"]
137 |
138 |
139 | dataset = MyDataset()
140 |
141 |
142 | results1 = []
143 | sentences = []
144 | prev_sentences = []
145 | prev_scores = []
146 | cur_doc = defaultdict(dict)
147 | cur_doc_idx = 0
148 | for i, out in enumerate(tqdm(toxigen_roberta()(dataset, batch_size=32, truncation=True, max_length=512), total=len(dataset))):
149 | if cur_doc_idx != dataset.data[i]["doc_idx"]:
150 | cur_doc_idx = dataset.data[i]["doc_idx"]
151 | results1.append({"toxigen_scores" : json.dumps(cur_doc)})
152 | cur_doc = defaultdict(dict)
153 | cur_doc[dataset.data[i]["sentence_idx"]] = { "TOXICITY_DIFF" : _conver_label_to_score(out) - dataset.data[i]["prev_score"]}
154 | sentences.append(dataset.data[i]["text"])
155 | prev_sentences.append(dataset.data[i]["prev_sentence"])
156 | prev_scores.append(dataset.data[i]["prev_score"])
157 | if len(cur_doc) > 0:
158 | results1.append({"toxigen_scores" : json.dumps(cur_doc)})
159 | cur_doc = defaultdict(dict)
160 |
161 |
162 | with open(os.path.join(base_dir, "perspective_scores_toxigen_only_context_exp.csv"), "w") as scores_f:
163 | spamwriter = csv.writer(scores_f)
164 | spamwriter.writerow(["DocIndex", "SentenceIndex", "API", "Category", "Score", "PrevScore", "PrevSentence", "CurSentence"])
165 | for results in [results1]:
166 | for i, (result, prev_score, prev_sentence, cur_sentence) in enumerate(zip(results, prev_scores, prev_sentences, sentences)):
167 | for t in ["profanity_check", "perspective", "toxigen"]:
168 | if f"{t}_scores" not in result:
169 | continue
170 | scores = json.loads(result[f"{t}_scores"])
171 | # doc number
172 | for k, v in scores.items():
173 | for k2, v2 in v.items():
174 | datarow = (i, k, t, k2, v2, prev_score, prev_sentence, cur_sentence)
175 | spamwriter.writerow(datarow)
176 |
177 |
178 | if __name__ == '__main__':
179 | parser = argparse.ArgumentParser()
180 | parser.add_argument("--split")
181 | parser.add_argument("--name")
182 |
183 | args = parser.parse_args()
184 | main(args)
185 |
--------------------------------------------------------------------------------
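Note: the context experiment above records TOXICITY_DIFF, the score of the sentence with its surrounding window minus the original sentence-level ToxiGen score, so negative values mean context lowered the score. A short sketch of summarizing that output (an illustration, not in the repo; the file name and columns match what the script writes):

# Summarize how often adding context lowers the ToxiGen score.
import pandas as pd

diffs = pd.read_csv("perspective_scores_toxigen_only_context_exp.csv")
diffs = diffs[diffs["Category"] == "TOXICITY_DIFF"]
print("mean diff:", diffs["Score"].mean())
print("share lowered by context:", (diffs["Score"] < 0).mean())
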
|