├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── benchmarks └── .gitignore ├── environment.yml ├── img ├── corpora.jpg ├── diagram.png └── example.jpeg ├── output └── convincing_arguments.json ├── schema ├── README.md ├── datasets.yaml └── pairs.yaml └── scripts ├── README.md ├── add_pairs.py ├── build_benchmark.sh ├── create_cluster.py ├── generate.py ├── get_embeddings.py ├── make_benchmark.py ├── parameters.py ├── pull_data.py ├── scrape_admin_statements.py ├── scrape_blm_countermovements.py ├── scrape_open_review.py ├── scrape_twitter_rumors.py ├── test_discriminative.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .DS_Store 6 | **/.DS_Store 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
 97 | __pypackages__/
 98 | 
 99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 | 
103 | # SageMath parsed files
104 | *.sage.py
105 | 
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 | 
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 | 
119 | # Rope project settings
120 | .ropeproject
121 | 
122 | # mkdocs documentation
123 | /site
124 | 
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 | 
130 | # Pyre type checker
131 | .pyre/
132 | 
133 | /webapp/src/firebase/config.js
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 Peter Zhang
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | OpenD5
 2 | ===
 3 | 
 4 | Authors: Ruiqi Zhong, Peter Zhang, Steve Li, JinWoo Ahn, Dan Klein, Jacob Steinhardt
 5 | 
 6 | [Paper link](https://arxiv.org/pdf/2302.14233.pdf)
 7 | 
 8 | This repository hosts OpenD5, a benchmark for discovering natural language facts from pairs of corpora. Our paper focuses on the setting of comparing two distributions of text and describing their differences in natural language. The repository containing the system itself is available [here](https://github.com/ruiqi-zhong/D5).
 9 | 
10 | ![](img/example.jpeg)
11 | 
12 | The benchmark spans a wide array of disciplines and problem types. A sibling repository that contains code for running our system on these problems is available [here](https://github.com/ruiqi-zhong/D5).
13 | 
14 | ![](img/corpora.jpg)
15 | 
16 | To create the full benchmark, you should 1) download these **[folders](https://drive.google.com/drive/folders/18N7-cwpXVtat9CocrMn_y7aI6Sbyo-Hn?usp=share_link)** and 2) run the [`build_benchmark.sh`](scripts/build_benchmark.sh) script from the root of the repo, as sketched below.
17 | 
18 | For more details, please refer to the
19 | - [instructions](scripts/README.md) for using the scripts
20 | - [explanations](schema/README.md) of the relevant schema
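
A minimal sketch of the full build, assuming the reproduction folders linked above have already been downloaded (the clone URL is inferred from where this repository is hosted; adjust paths as needed):

```bash
# Sketch: assemble the full benchmark from a fresh checkout.
git clone https://github.com/petezh/OpenD5.git
cd OpenD5
# Move the downloaded reproduction folders (clusters, dataset,
# downloads, manual, output) into this directory, then run:
bash scripts/build_benchmark.sh  # creates the conda env, pulls data, assembles
```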
21 | 
22 | # Downloads
23 | 
24 | - The 675 problems in the original paper are available [**here**](https://doi.org/10.5281/zenodo.7662705).
25 | - An extension with 37 additional problems is available **[here](https://drive.google.com/file/d/1meT_ewZrHRZM_VcAMos3eWHSYIMnlLED/view?usp=sharing)**.
26 | - A reproduction package for the entire dataset is available **[here](https://drive.google.com/drive/folders/18N7-cwpXVtat9CocrMn_y7aI6Sbyo-Hn?usp=share_link)**. It includes the additional source data required to assemble the full dataset.
27 | 
28 | # Contributing
29 | 
30 | If you'd like to contribute additional problems to the benchmark, please:
31 | - Create a script for constructing the various splits of the dataset (see [`pull_data.py`](scripts/pull_data.py)).
32 | - Add the dataset's relevant metadata to the [`datasets.yaml`](schema/datasets.yaml) and [`pairs.yaml`](schema/pairs.yaml) schema.
33 | - Create a pull request and list the relevant citation.
34 | - Email petez@berkeley.edu with any questions.
35 | 
36 | # BibTeX
37 | 
38 | ```bibtex
39 | @article{zhong2023goal,
40 |   title={Goal Driven Discovery of Distributional Differences via Language Descriptions},
41 |   author={Zhong, Ruiqi and Zhang, Peter and Li, Steve and Ahn, Jinwoo and Klein, Dan and Steinhardt, Jacob},
42 |   journal={arXiv preprint arXiv:2302.14233},
43 |   year={2023}
44 | }
45 | ```
46 | 
--------------------------------------------------------------------------------
/benchmarks/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
 1 | name: opend5
 2 | channels:
 3 |   - defaults
 4 | dependencies:
 5 |   - tqdm
 6 |   - nltk
 7 |   - pandas
 8 |   - pyyaml
 9 |   - pytorch
10 |   - transformers
11 |   - numpy
12 |   - python
13 |   - markdown
14 |   - scikit-learn
15 |   - gdown
16 |   - tweepy
17 |   - pdfplumber
18 |   - gzip
--------------------------------------------------------------------------------
/img/corpora.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/petezh/OpenD5/096c88674344e22593173c2611364b7cbfc01f4b/img/corpora.jpg
--------------------------------------------------------------------------------
/img/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/petezh/OpenD5/096c88674344e22593173c2611364b7cbfc01f4b/img/diagram.png
--------------------------------------------------------------------------------
/img/example.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/petezh/OpenD5/096c88674344e22593173c2611364b7cbfc01f4b/img/example.jpeg
--------------------------------------------------------------------------------
/schema/README.md:
--------------------------------------------------------------------------------
 1 | OpenD5 Schema
 2 | ===
 3 | 
 4 | Each problem is [represented](https://github.com/ruiqi-zhong/D5#problem-representation) as a combination of a pair of corpora and a research goal. See below for an example.
 5 | 
 6 | ![](../img/example.jpeg)
 7 | 
 8 | [`pairs.yaml`](pairs.yaml) contains most of the metadata (an illustrative entry is sketched after this list):
 9 | - The highest level of structure is each dataset's abbreviation (e.g. `abc_headlines`).
10 | - Each dataset contains several *generations*, which describe different ways of splitting a dataset.
11 | - A generation's description states the feature used for splitting (e.g. `the year they were published`).
12 | - Applications consist of a `target` and a `user`, along with some `example_hypotheses`.
13 | - The `v2-origid` field is for internal tracking and is non-essential.
14 | - If the context should allow for any kind of hypothesis, the `purely_exploratory` flag is set to `True`.
15 | - The `pair_type` is assigned according to the taxonomy outlined below.
16 | - If a `flip` is appropriate, the user should consider both the original and swapped (B vs. A) versions of the problem.
17 | - Each generation contains a list of `pairs`.
18 | - Each pair has a `pos_desc` and a `neg_desc` (e.g. `are ABC news headlines from 2007`).
19 | - The classes correspond to distribution names.
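
For concreteness, here is a hypothetical entry in the shape described above. Every field value is invented for illustration, and the authoritative key names and values are those used in [`pairs.yaml`](pairs.yaml) itself.

```yaml
# Hypothetical pairs.yaml entry, for illustration only.
abc_headlines:
  generations:
  - generation: the year they were published
    applications:
    - target: how news coverage has shifted over time
      user: a journalist studying media trends
      purely_exploratory: false
    example_hypotheses:
    - mentions an economic recession
    pair_type: time
    flip: true
    pairs:
    - name: abc_headlines_year_2007_2008
      pos_desc: are ABC news headlines from 2007
      neg_desc: are ABC news headlines from 2008
```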
20 | 
21 | [`datasets.yaml`](datasets.yaml) contains dataset-level features:
22 | - A description of the types of text samples, e.g. `headlines published by ABC news, an American news company`.
23 | - The `discipline` and `expertise` required for the dataset.
24 | - The `status` of the dataset, which is one of:
25 |   - `public` if there is a public license.
26 |   - `private` if the dataset was privately shared.
27 |   - `accessible` if the dataset can be accessed easily but has an unclear license.
28 | - The preprocessing steps, which are mostly for internal tracking.
--------------------------------------------------------------------------------
/schema/datasets.yaml:
--------------------------------------------------------------------------------
  1 | abc_headlines:
  2 |   description: headlines published by ABC news, an American news company
  3 |   discipline: social science
  4 |   expertise: none
  5 |   preprocessing: ABC headlines are directly downloaded from Harvard Dataverse. The
  6 |     year is extracted from the publication date field. Samples are constructed from
  7 |     the headline text.
  8 |   status: public
  9 |   url: https://www.kaggle.com/therohk/million-headlines/data
 10 | ad_transcripts:
 11 |   description: ad scripts from a variety of industries
 12 |   discipline: business
 13 |   expertise: none
 14 |   preprocessing: Ad transcripts are directly downloaded from Kaggle. The top eight
 15 |     industries by frequency are selected. Newlines are replaced with spaces.
 16 |   status: public
 17 |   url: https://www.kaggle.com/datasets/kevinhartman0/advertisement-transcripts-from-various-industries
 18 | admin_statements:
 19 |   description: statements of administration policy from American presidents
 20 |   discipline: social science
 21 |   expertise: familiarity with American policy
 22 |   preprocessing: Administration statements are extracted from a collection hosted
 23 |     on GitHub. Extraneous symbols are removed and samples are split by paragraph.
 24 |   status: public
 25 |   url: https://github.com/unitedstates/statements-of-administration-policy
 26 | ai2_natural_instruction:
 27 |   description: a learning-from-instructions dataset released by the Allen Institute
 28 |     for AI
 29 |   discipline: machine learning
 30 |   expertise: none
 31 |   preprocessing: Natural instruction tasks are directly downloaded without modification.
 32 |   status: public
 33 |   url: https://instructions.apps.allenai.org/
 34 | airline_reviews:
 35 |   description: reviews of airlines collected from the review website Skytrax
 36 |   discipline: social science
 37 |   expertise: none
 38 |   preprocessing: Airline reviews for airlines, airports, and seats are downloaded
 39 |     from a public GitHub repository. Names of aircraft, airlines, countries, and
 40 |     traveler types are standardized.
Ratings of 1, 4, or 5 on a scale of 5, and 1,
 41 |     5, 8, or 10 on a scale of 10 are kept.
 42 |   status: public
 43 |   url: https://github.com/quankiquanki/skytrax-reviews-dataset
 44 | aita:
 45 |   description: posts on the "Am I The Asshole" Subreddit, an online forum where people ask
 46 |     others whether they were in the wrong
 47 |   discipline: social science
 48 |   expertise: none
 49 |   preprocessing: Posts from r/AmITheAsshole are downloaded from a praw scrape of Reddit.
 50 |     Topic areas are chosen based on common themes in posts and coarsely defined based
 51 |     on manual keywords. Each post can belong to multiple topic areas.
 52 |   status: public
 53 |   url: https://github.com/iterative/aita_dataset
 54 | all_the_news:
 55 |   description: news articles collected from various outlets between 2015 and 2017
 56 |   discipline: social science
 57 |   expertise: none
 58 |   preprocessing: News articles are downloaded directly from the Components website. The
 59 |     titles are used as text samples.
 60 |   status: public
 61 |   url: https://www.kaggle.com/datasets/snapcrack/all-the-news
 62 | amazon_reviews:
 63 |   description: Amazon reviews collected from various product categories
 64 |   discipline: business
 65 |   expertise: none
 66 |   preprocessing: Amazon reviews are downloaded from a 2018 crawl of the website. The
 67 |     first 100,000 review texts are treated as the text sample.
 68 |   status: public
 69 |   url: https://nijianmo.github.io/amazon/index.html
 70 | armenian_jobs:
 71 |   description: job postings in Armenia
 72 |   discipline: business
 73 |   expertise: none
 74 |   preprocessing: The Armenian job postings dataset is downloaded from a snapshot on GitHub.
 75 |     Different IT jobs are manually coded and time intervals are defined in order to
 76 |     balance sample availability.
 77 |   status: public
 78 |   url: https://www.kaggle.com/datasets/udacity/armenian-online-job-postings
 79 | blm_countermovements:
 80 |   description: Tweets about the All Lives Matter, Blue Lives Matter, and White Lives
 81 |     Matter movements
 82 |   discipline: humanities
 83 |   expertise: Familiarity with the BLM movement
 84 |   preprocessing: Tweet IDs are downloaded from the original paper and, where available,
 85 |     collected from the current API. Due to API rate limits, only 1,000 Tweets are
 86 |     sampled from each movement.
 87 |   status: public
 88 |   url: https://databank.illinois.edu/datasets/IDB-9614170
 89 | blogs:
 90 |   description: blog posts scraped from blogger.com in August of 2014
 91 |   discipline: social science
 92 |   expertise: none
 93 |   preprocessing: Blogs are downloaded directly from Kaggle, and the first 1 million
 94 |     blog posts are kept.
 95 |   status: public
 96 |   url: https://www.kaggle.com/datasets/rtatman/blog-authorship-corpus
 97 | boolq:
 98 |   description: a reading comprehension dataset of yes/no questions
 99 |   discipline: machine learning
100 |   expertise: none
101 |   preprocessing: Boolean questions are downloaded directly as is.
102 |   status: public
103 |   url: https://github.com/google-research-datasets/boolean-questions
104 | cah:
105 |   description: cards from games of Cards Against Humanity, an adult party game
106 |   discipline: humanities
107 |   expertise: none
108 |   preprocessing: Cards Against Humanity plays are retrieved from direct correspondence
109 |     with authors and accessed locally.
110 |   status: private
111 |   url: https://paperswithcode.com/dataset/cards-against-humanity
112 | clickbait_headlines:
113 |   description: headlines across time from the Examiner, a clickbait news site
114 |   discipline: social science
115 |   expertise: none
116 |   preprocessing: The Examiner headlines are directly downloaded from Kaggle. The year
117 |     is extracted from the publication date field. Samples are constructed from the
118 |     headline text.
119 |   status: public
120 |   url: https://www.kaggle.com/datasets/therohk/examine-the-examiner
121 | convincing_arguments:
122 |   description: arguments on a variety of topics annotated for convincingness
123 |   discipline: humanities
124 |   expertise: Understanding of debate topics
125 |   preprocessing: Annotated arguments are downloaded from the GitHub repository. Arguments
126 |     are sorted by rank. The bottom 400 are treated as "unconvincing", the top 200
127 |     are treated as "convincing", and the next 200 are treated as "somewhat convincing."
128 |   status: public
129 |   url: https://github.com/UKPLab/acl2016-convincing-arguments
130 | craigslist_negotiations:
131 |   description: dialogue from negotiations on Craigslist, an online seller platform
132 |   discipline: humanities
133 |   expertise: none
134 |   preprocessing: Craigslist negotiations are downloaded from Hugging Face. Sequences
135 |     which contained a "quit" intention or "reject" intention are categorized as failures;
136 |     those which contained an "accept" intention are categorized as successes. The
137 |     mid-price is defined as the mean price of the items sold. Within each category,
138 |     the items are sorted by mid-price. The top half is treated as high-price and the
139 |     bottom half is treated as low-price.
140 |   status: public
141 |   url: https://huggingface.co/datasets/craigslist_bargains
142 | debate:
143 |   description: evidence compiled for American competitive policy debate, published
144 |     online by debate camps
145 |   discipline: humanities
146 |   expertise: Debate knowledge
147 |   preprocessing: The train split is downloaded from Hugging Face. For each sample,
148 |     we use the abstract as the text. Arguments are categorized by type, debate camp
149 |     of origin, and topic/specific argument. For topics, we use domain knowledge to
150 |     list relevant keywords for each topic and include any sample with a file name
151 |     that includes any keyword. A single sample can belong to multiple topics.
152 |   status: public
153 |   url: https://huggingface.co/datasets/Hellisotherpeople/DebateSum
154 | dice_jobs:
155 |   description: American technology job postings on dice.com
156 |   discipline: business
157 |   expertise: none
158 |   preprocessing: Job postings are downloaded from Kaggle. Posts from the six most
159 |     popular companies are categorized by company. We remove miscellaneous characters
160 |     and blank descriptions. We additionally apply our splitting procedure to reduce
161 |     description length.
162 |   status: public
163 |   url: https://www.kaggle.com/datasets/PromptCloudHQ/us-technology-jobs-on-dicecom
164 | diplomacy_deception:
165 |   description: dialogue from games of Diplomacy, which involves deception
166 |   discipline: social science
167 |   expertise: Familiarity with the game Diplomacy
168 |   preprocessing: Diplomacy dialogues are downloaded from GitHub (all splits). The
169 |     data are ASCII encoded and newlines are removed. Each message and label is treated
170 |     as a sample.
171 |   status: public
172 |   url: https://huggingface.co/datasets/diplomacy_detection
173 | drug_experiences:
174 |   description: self-reports of experiences with various illicit drugs from Erowid.com
175 |   discipline: health
176 |   expertise: Familiarity with drug-related language
177 |   preprocessing: Drug experiences are downloaded from a GitHub repository. For each
178 |     sample, we remove HTML formatting, split samples by paragraphs, and keep only
179 |     paragraphs with over 50 characters.
180 |   status: private
181 |   url: https://github.com/technillogue/erowid-w2v
182 | echr_decisions:
183 |   description: facts of cases heard before the European Court of Human Rights
184 |   discipline: social science
185 |   expertise: Ability to understand legalese
186 |   preprocessing: Decisions are downloaded from a public archive. A random sample of
187 |     500 decisions is selected from the files. The samples with any violated articles
188 |     are categorized as "violation," while the rest are categorized as "no violation."
189 |   status: public
190 |   url: https://paperswithcode.com/dataset/echr
191 | essay_scoring:
192 |   description: essays from students
193 |   discipline: social science
194 |   expertise: none
195 |   preprocessing: Essays are downloaded from a GitHub repository. Only essays from
196 |     set 5 are considered. Essays with a score of at least 3 are categorized as good
197 |     essays, while essays with a score less than 3 are bad essays.
198 |   status: public
199 |   url: https://www.kaggle.com/c/asap-aes
200 | fake_news:
201 |   description: fake and legitimate news
202 |   discipline: social science
203 |   expertise: none
204 |   preprocessing: Fake news articles are downloaded from the author's website. Full
205 |     articles are treated as text snippets.
206 |   status: public
207 |   url: http://web.eecs.umich.edu/~mihalcea/downloads.html#FakeNews
208 | fomc_speeches:
209 |   description: Federal Open Market Committee (FOMC) speeches from 1996-2020, which
210 |     describe Federal Reserve policy
211 |   discipline: social science
212 |   expertise: Familiarity with Federal Reserve policy
213 |   preprocessing: Fed speeches are downloaded from Kaggle. The macro indicator data
214 |     are merged in on the year and month. Full speech text is split by paragraph and
215 |     categorized by speaker, year, and macroeconomic indicator.
216 |   status: public
217 |   url: https://www.kaggle.com/datasets/natanm/federal-reserve-governors-speeches-1996-2020
218 | genius_lyrics:
219 |   description: lyrics collected from Genius.com before 2020
220 |   discipline: humanities
221 |   expertise: none
222 |   preprocessing: Genius lyrics are downloaded from a Google Drive. The lyrics are
223 |     merged with song metadata and treated as samples. We categorize lyrics by hand-selecting
224 |     popular artists, common genres, time periods, and view counts (over 1M views is
225 |     high, 500k-1M is medium).
226 |   status: public
227 |   url: https://www.cs.cornell.edu/~arb/data/genius-expertise/
228 | happy_moments:
229 |   description: self-reported happy moments and demographic characteristics
230 |   discipline: social science
231 |   expertise: none
232 |   preprocessing: The HappyDB dataset is downloaded from the official GitHub repository.
233 |     Demographic data is cleaned and merged into the happy moments. Happy moment descriptions
234 |     are treated as samples and are categorized by type of happy moment, country of
235 |     origin, and other demographic features.
236 |   status: public
237 |   url: https://github.com/megagonlabs/HappyDB
238 | huff_post_headlines:
239 |   description: headlines from the news outlet Huffington Post
240 |   discipline: social science
241 |   expertise: none
242 |   preprocessing: Huffington Post headlines are downloaded from Kaggle. The short description
243 |     of each article is treated as a sample and tokenized at the sentence level.
244 |   status: public
245 |   url: https://rishabhmisra.github.io/publications/
246 | immigration_speeches:
247 |   description: congressional and presidential speeches that mention immigration from
248 |     1880 to the present
249 |   discipline: social science
250 |   expertise: none
251 |   preprocessing: Immigration speeches are downloaded from the replication package.
252 |     The speech text is preprocessed to remove extraneous spaces. We engineer features
253 |     corresponding to time periods, well-known speakers, other significant time periods,
254 |     racial group under discussion, and geographic area of the United States.
255 |   status: public
256 |   url: https://github.com/dallascard/us-immigration-speeches/releases
257 | kickstarter:
258 |   description: names of startups on kickstarter.com
259 |   discipline: business
260 |   expertise: none
261 |   preprocessing: We download a 2018 crawl of Kickstarter from Kaggle. The project
262 |     name is treated as the text sample.
263 |   status: public
264 |   url: https://www.kaggle.com/datasets/kemical/kickstarter-projects?select=ks-projects-201612.csv
265 | microedit_humor:
266 |   description: funny sentences generated by making one-word edits to normal statements
267 |   discipline: social science
268 |   expertise: none
269 |   preprocessing: The Microedit dataset is downloaded from the author's website. We make
270 |     the relevant edit to each text sample and treat the edited text sample as the
271 |     data point. We bin the mean annotator grade into four bins and denote them as unfunny,
272 |     neutral, funny, and very funny, respectively.
273 |   status: public
274 |   url: https://paperswithcode.com/dataset/humicroedit
275 | mnli:
276 |   description: a collection of sentence pairs annotated with textual entailment information
277 |     from a range of genres
278 |   discipline: machine learning
279 |   expertise: none
280 |   preprocessing: The MNLI corpus is downloaded from the official website. We treat
281 |     the premise and hypothesis as text samples.
282 |   status: public
283 |   url: https://cims.nyu.edu/~sbowman/multinli/
284 | monster_jobs:
285 |   description: American job postings on monster.com
286 |   discipline: business
287 |   expertise: none
288 |   preprocessing: Jobs on Monster.com are downloaded from Kaggle. Job descriptions
289 |     are treated as samples and split at the paragraph and sentence level. We keep
290 |     and categorize jobs from seventeen large cities.
291 |   status: public
292 |   url: https://www.kaggle.com/datasets/PromptCloudHQ/us-jobs-on-monstercom
293 | movie_tmdb:
294 |   description: movie plot summaries from TMDB
295 |   discipline: business
296 |   expertise: none
297 |   preprocessing: TMDB movie overviews are downloaded from Kaggle. We keep only English
298 |     movies and bin popularity by deciles. The top decile is considered "hits," the
299 |     70-80th percentiles are considered "average," and the 30-40th percentiles are
300 |     considered "bad."
301 |   status: accessible
302 |   url: https://www.themoviedb.org/
303 | movie_wiki:
304 |   description: movie plot summaries collected from Wikipedia
305 |   discipline: social science
306 |   expertise: none
307 |   preprocessing: Wikipedia movie summaries are downloaded from Kaggle.
308 |   status: public
309 |   url: https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots
310 | news_popularity:
311 |   description: news headlines posted on social media platforms
312 |   discipline: business
313 |   expertise: none
314 |   preprocessing: Headlines are downloaded from a reproduction package. Headline and
315 |     title text are cleaned, and the title is treated as the text sample. For each topic,
316 |     the 100 most positive and most negative articles (or the 100 most popular and most
317 |     unpopular) are used as distributions.
318 |   status: public
319 |   url: https://archive.ics.uci.edu/ml/datasets/News+Popularity+in+Multiple+Social+Media+Platforms
320 | nli_benchmarks:
321 |   description: training examples from various natural language inference (NLI) datasets
322 |   discipline: machine learning
323 |   expertise: none
324 |   preprocessing: NLI benchmarks are downloaded from a public collection on Google
325 |     Drive. We examine the premise and hypothesis separately as samples.
326 |   status: public
327 |   url: https://github.com/alisawuffles/wanli
328 | npt_conferences:
329 |   description: Non-Proliferation of Nuclear Weapons (NPT) conference transcripts
330 |   discipline: social science
331 |   expertise: Familiarity with nuclear policy
332 |   preprocessing: 'NPT conference notes are extracted from the accompanying replication
333 |     package. Text is split by paragraph, and only paragraphs longer than 50 characters
334 |     are preserved. Text is split into three time ranges: pre-2008, 2008-2012, and
335 |     post-2012.'
336 |   status: public
337 |   url: https://journals.sagepub.com/doi/full/10.1177/0022343320960523
338 | open_deception:
339 |   description: arbitrary lies and truths from any domain generated by crowdworkers
340 |   discipline: social science
341 |   expertise: none
342 |   preprocessing: Open domain lies are downloaded from the public dataset, and the texts
343 |     are split into lies and truths.
344 |   status: public
345 |   url: https://web.eecs.umich.edu/~mihalcea/downloads.html#OpenDeception
346 | open_review:
347 |   description: submissions to ICLR, a machine learning conference, from 2018 to 2021
348 |   discipline: machine learning
349 |   expertise: familiarity with machine learning
350 |   preprocessing: 'Open review abstracts are accessed via the openreview API. We query
351 |     for abstracts from the 2018-2021 ICLR blind submissions. Abstracts are classified
352 |     based on rating: >=7 ("great"), 5-6 ("good"), and <=4 ("bad").'
353 |   status: accessible
354 |   url: https://openreview.net/
355 | oral_histories:
356 |   description: oral histories from the United States
357 |   discipline: humanities
358 |   expertise: none
359 |   preprocessing: Oral histories are downloaded from the paper's accompanying GitHub
360 |     repository. Histories are classified according to the birth year of the author (pre-1930,
361 |     1930-1949, post-1950), the race of the speaker (black, Asian, white), college
362 |     education (graduate/bachelors or none), and place of birth (South or not South,
363 |     as defined by the Census Bureau). We treat the full oral history as the text sample.
364 |   status: private
365 |   url: https://github.com/ohtap/ohtap
366 | parenting_reddit_users:
367 |   description: posts from individual parents in parenting-related forums on Reddit
368 |   discipline: social science
369 |   expertise: none
370 |   preprocessing: Individual posts are retrieved with permission from the author. We
371 |     sample 5,000 posts per year. We use authorship histories to estimate how long
372 |     each author has been posting on parenting-related subreddits and split according
373 |     to various account ages. We use posts on mom- and dad-related subreddits to guess
374 |     user gender and split accordingly.
375 |   status: private
376 |   url: https://github.com/SALT-NLP/Parenting_OnlineUsage
377 | parenting_subreddits:
378 |   description: posts from various parenting-related subreddits, which are text-based
379 |     forums on the site Reddit
380 |   discipline: social science
381 |   expertise: none
382 |   preprocessing: Posts from various subreddits are downloaded from the paper's GitHub
383 |     repository. We clean the text and split the posts according to the topic(s) each
384 |     post is tagged with.
385 |   status: public
386 |   url: https://github.com/SALT-NLP/Parenting_OnlineUsage
387 | poetry:
388 |   description: poems from PoetryFoundation.com
389 |   discipline: humanities
390 |   expertise: literary knowledge
391 |   preprocessing: Poems are downloaded from a 2019 scrape of the PoetryFoundation website
392 |     from Kaggle. The text is cleaned and split according to subject tags and authorship.
393 |   status: public
394 |   url: https://www.kaggle.com/datasets/tgdivy/poetry-foundation-poems
395 | political_ads:
396 |   description: political ads observed by Facebook users
397 |   discipline: social science
398 |   expertise: none
399 |   preprocessing: Ads are downloaded from the Ad Observer website, which maintains
400 |     an aggregate of all collected ads. We extract targeting metadata from the targeting
401 |     field and define splits according to age, gender, location, interests, time, and
402 |     political lean.
403 |   status: public
404 |   url: https://adobserver.org/
405 | politifact:
406 |   description: fact-checks from the popular fact-checking website Politifact
407 |   discipline: social science
408 |   expertise: none
409 |   preprocessing: .nan
410 |   status: public
411 |   url: https://www.kaggle.com/datasets/rmisra/politifact-fact-check-dataset
412 | qqp:
413 |   description: questions from Quora.com
414 |   discipline: machine learning
415 |   expertise: none
416 |   preprocessing: .nan
417 |   status: public
418 |   url: https://paperswithcode.com/dataset/quora-question-pairs
419 | radiology_diagnosis:
420 |   description: impressions and medical histories of radiology patients
421 |   discipline: health
422 |   expertise: Medical
423 |   preprocessing: Radiology diagnoses are downloaded from a GitHub copy of the original
424 |     task dataset. We parse the metadata to retrieve the diagnostic code, decision
425 |     type, impression, and patient history. Referencing the associated ICD codes, we
426 |     convert codes to colloquial diagnoses (e.g. 786.2 denotes cough). We treat the
427 |     histories and impressions as samples and split them according to diagnosis and
428 |     level of consensus.
429 |   status: public
430 |   url: https://aclanthology.org/W07-1013/
431 | rate_my_prof:
432 |   description: reviews of lecturers from RateMyProfessor.com
433 |   discipline: social science
434 |   expertise: none
435 |   preprocessing: A sample of RateMyProfessor.com reviews is downloaded from an online repository.
436 |     We clean the text and guess the gender of the reviewed lecturer from the first
437 |     name using the gender_guesser package. Due to data availability, we consider only
438 |     male and female names. To improve the quality of the classification, we remove
439 |     any posts which use pronouns from the opposing sex (e.g. "him").
440 |   status: public
441 |   url: https://data.mendeley.com/datasets/fvtfjyvw7d/2
442 | reddit_humor:
443 |   description: jokes posted on the Reddit forum r/Jokes, a message board for sharing
444 |     jokes
445 |   discipline: humanities
446 |   expertise: none
447 |   preprocessing: Jokes are downloaded from the dev and test splits of the dataset.
448 |     We clean the text and split the dataset according to whether the jokes are labeled
449 |     as funny.
450 |   status: public
451 |   url: https://aclanthology.org/2020.lrec-1.753/
452 | reddit_stress:
453 |   description: stress-related posts on Reddit
454 |   discipline: health
455 |   expertise: none
456 |   preprocessing: Reddit posts are downloaded from a GitHub repository. We split the
457 |     posts based on which subreddit they are posted on (related to PTSD, anxiety,
458 |     or stress generally).
459 |   status: public
460 |   url: https://aclanthology.org/D19-6213.pdf
461 | reuters_authorship:
462 |   description: articles from various Reuters authors
463 |   discipline: humanities
464 |   expertise: none
465 |   preprocessing: Reuters articles are downloaded from the UCI repository. The articles
466 |     are split according to author.
467 |   status: public
468 |   url: https://archive.ics.uci.edu/ml/datasets/Reuter_50_50
469 | riddles:
470 |   description: common English words
471 |   discipline: social science
472 |   expertise: none
473 |   preprocessing: The 3000 most common English words are manually copied from a website.
474 |     Words with between 5 and 8 characters are kept. We create splits based on two popular riddles.
475 |     First, we split words based on whether they have a duplicate character. We exclude
476 |     any words with multiple "doubles" or more than 2 of any character. Second, we
477 |     split words based on whether they have the letter T.
478 |   status: public
479 |   url: https://www.ef.edu/english-resources/english-vocabulary/top-3000-words/
480 | scotus_cases:
481 |   description: facts from cases heard by the Supreme Court of the United States (SCOTUS)
482 |   discipline: social science
483 |   expertise: legal knowledge
484 |   preprocessing: Supreme Court cases are downloaded from a GitHub repository. We identify
485 |     state/federal parties by manually defining keywords. We split based on the winning
486 |     party, the identity of each party, and the type of decision. We then define several
487 |     time periods and relevant political eras and split decisions accordingly. Finally,
488 |     we split according to the ruling's policy area and how it changes over time.
489 |   status: public
490 |   url: https://paperswithcode.com/paper/justice-a-benchmark-dataset-for-supreme-court
491 | short_answer_scoring:
492 |   description: short answers from students
493 |   discipline: social science
494 |   expertise: none
495 |   preprocessing: Short answers are downloaded from a GitHub mirror of the dataset.
496 |     We consider only responses to essay set 1. The two scores are averaged and binned
497 |     into good (>= 2.5), medium (1.5-2.5), and bad (<1.5).
498 |   status: public
499 |   url: https://www.kaggle.com/c/asap-sas
500 | snli:
501 |   description: a collection of sentence pairs annotated with textual entailment information
502 |     from image captions
503 |   discipline: machine learning
504 |   expertise: none
505 |   preprocessing: .nan
506 |   status: public
507 |   url: https://nlp.stanford.edu/projects/snli/
508 | squad_v2:
509 |   description: reading comprehension questions crowdsourced from Wikipedia articles
510 |   discipline: machine learning
511 |   expertise: none
512 |   preprocessing: .nan
513 |   status: public
514 |   url: https://rajpurkar.github.io/SQuAD-explorer/
515 | stock_news:
516 |   description: top news headlines on Reddit, an online message board
517 |   discipline: business
518 |   expertise: Knowledge of stock market and financial events
519 |   preprocessing: Headlines are downloaded from a GitHub mirror. We clean the text
520 |     and divide the samples based on whether the Dow rose or fell that day.
521 |   status: public
522 |   url: https://github.com/ShravanChintha/Stock-Market-prediction-using-daily-news-headlines
523 | suicide_notes:
524 |   description: posts from r/SuicideWatch and r/depression, two forums on Reddit
525 |   discipline: health
526 |   expertise: none
527 |   preprocessing: Reddit posts are downloaded from a GitHub repository. The post title
528 |     and body are combined to form the text samples. Samples are split based on whether
529 |     they were posted in a suicide-related Subreddit.
530 |   status: public
531 |   url: https://github.com/hesamuel/goodbye_world
532 | times_india_headlines:
533 |   description: headlines from Times of India news
534 |   discipline: social science
535 |   expertise: none
536 |   preprocessing: Headlines are downloaded from a Dataverse mirror. We use the first
537 |     1000 headlines in each year as samples.
538 |   status: public
539 |   url: https://www.kaggle.com/datasets/therohk/india-headlines-news-dataset
540 | trial_deception:
541 |   description: testimonies from witnesses in real trials
542 |   discipline: social science
543 |   expertise: none
544 |   preprocessing: Trial testimonies are downloaded from the author's website. The testimonies
545 |     are divided based on whether they are considered truthful.
546 |   status: public
547 |   url: https://web.eecs.umich.edu/~mihalcea/downloads.html#RealLifeDeception
548 | tweet_gender:
549 |   description: random Tweets
550 |   discipline: social science
551 |   expertise: none
552 |   preprocessing: Tweets are downloaded from a GitHub mirror. We consider only Tweets
553 |     with a 100% confidence rating. The tweets are split into male and female
554 |     gender groupings.
555 |   status: accessible
556 |   url: https://github.com/tranctan/Gender-Classification-based-on-Twritter-textual-data
557 | tweet_rumor:
558 |   description: Tweets about various rumors
559 |   discipline: social science
560 |   expertise: Familiarity with specific rumors
561 |   preprocessing: Twitter IDs are downloaded from a Zenodo archive, and 300 Tweets
562 |     are collected for each rumor using the Twitter API. Tweets are evenly divided into early, middle,
563 |     and late thirds based on the publication time.
564 |   status: private
565 |   url: https://zenodo.org/record/2563864#.YzobLOzMK3I
566 | twitter_bots:
567 |   description: Tweets from users identified as bots or humans
568 |   discipline: social science
569 |   expertise: none
570 |   preprocessing: Annotated Tweets are downloaded from an online repository.
We filter
571 |     out non-English Tweets using the guess_language package and exclude any Tweets
572 |     that contain the words "fake" or "bot." For Tweets from traditional bots, social
573 |     bots, and humans, we sample 20,000 of each.
574 |   status: accessible
575 |   url: http://mib.projects.iit.cnr.it/dataset.html
576 | twitter_misspellings:
577 |   description: a collection of Tweets without emojis
578 |   discipline: social science
579 |   expertise: none
580 |   preprocessing: Assorted Tweets are downloaded from a GitHub mirror. We manually
581 |     identify eight common misspellings of words ("your", "with", "that", "going",
582 |     "know", "you", "what", "the") and divide samples based on whether they contain
583 |     each misspelling.
584 |   status: accessible
585 |   url: https://www.kaggle.com/datasets/kazanova/sentiment140
586 | twitter_sentiment140:
587 |   description: random Tweets
588 |   discipline: social science
589 |   expertise: none
590 |   preprocessing: Assorted Tweets are downloaded from a mirror and the text is used
591 |     as-is for clustering.
592 |   status: accessible
593 |   url: https://www.kaggle.com/datasets/kazanova/sentiment140
594 | un_debates:
595 |   description: speeches from debates at the United Nations
596 |   discipline: social science
597 |   expertise: Familiarity with political events
598 |   preprocessing: Debate transcripts are downloaded from the Dataverse reproduction
599 |     package. Samples are divided based on the country and year of the snippet. First,
600 |     we isolate samples from Russia, China, and the United States and specify 3 time
601 |     periods of interest. Next, we divide all samples by the decade. Finally, we create
602 |     distributions for 19 countries of interest.
603 |   status: public
604 |   url: https://doi.org/10.7910/DVN/0TJX8Y
605 | unhealthy_conversations:
606 |   description: expert-annotated unhealthy conversations
607 |   discipline: humanities
608 |   expertise: none
609 |   preprocessing: Conversation transcripts are downloaded from the official GitHub
610 |     repository. For each annotated attribute, we split the dataset based on whether
611 |     that form of unhealthy conversation is present in the sample.
612 |   status: public
613 |   url: https://github.com/conversationai/unhealthy-conversations
614 | urban_dictionary:
615 |   description: definitions from UrbanDictionary.com, a crowdsourced English dictionary
616 |   discipline: humanities
617 |   expertise: none
618 |   preprocessing: Urban Dictionary entries are downloaded from Kaggle. Definitions
619 |     are split into groups representing the top 1, 5, and 10 percent of definitions
620 |     ranked by both upvotes and downvotes; we sample 10,000 from each and create a
621 |     control distribution by randomly sampling 10,000 definitions from all entries.
622 |   status: public
623 |   url: https://www.kaggle.com/therohk/urban-dictionary-words-dataset
624 | wikitext:
625 |   description: text snippets from Wikipedia
626 |   discipline: machine learning
627 |   expertise: none
628 |   preprocessing: The Wikipedia snippets are loaded from Hugging Face. We remove any
629 |     samples that are empty or start with '=' (which represent headings); samples are
630 |     tokenized at the sentence level and used for clustering.
631 |   status: public
632 |   url: https://huggingface.co/datasets/wikitext
633 | yc_startups:
634 |   description: descriptions of companies that were part of the Y Combinator startup
635 |     incubator
636 |   discipline: business
637 |   expertise: none
638 |   preprocessing: Y Combinator company descriptions are downloaded from a 2022 scrape
639 |     on GitHub.
Only companies with long descriptions are preserved. Companies are split
640 |     according to founder characteristics, year, "top company" designation, operating
641 |     status, and location.
642 |   status: public
643 |   url: https://github.com/akshaybhalotia/yc_company_scraper
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
 1 | OpenD5 Scripts
 2 | ===
 3 | 
 4 | ![](../img/diagram.png)
 5 | 
 6 | Most of the existing intermediary files have been removed from GitHub and are instead available from this [reproduction package](https://drive.google.com/drive/u/0/folders/18N7-cwpXVtat9CocrMn_y7aI6Sbyo-Hn). Please move all of the following folders into the repo before re-running the scripts below: `clusters`, `dataset`, `downloads`, `manual`, `output`.
 7 | 
 8 | You can generate the benchmark with accessible datasets using the `build_benchmark.sh` script.
 9 | 
10 | All folder names are arbitrary and can be modified in the [`parameters.py`](parameters.py) script.
11 | 
12 | Be sure to create the associated conda environment from the `environment.yml` file in the main repo using:
13 | 
14 | `conda env create --file environment.yml -n opend5`
15 | 
16 | # Pulling Data
17 | 
18 | The [`pull_data.py`](pull_data.py) script contains individual functions for each dataset in the benchmark. Each function has one of three methods for obtaining data:
19 | 1. **Download**: The function downloads a mirror of an existing dataset and leaves it mostly intact. Datasets are variously downloaded from GitHub repositories, Zenodo, Harvard Dataverse, sites hosted by authors, and other sources of reproduction material.
20 | 2. **Scrape**: Some datasets (e.g. [`open_review`](scrape_open_review.py) or [`admin_statements`](scrape_admin_statements.py)) are constructed by collecting data from an API or crawler. The collection code for each of these datasets lives in a separate script whose name begins with `scrape`.
21 | 3. **Manual**: For datasets without an easily accessible URL, source files are downloaded manually. The respective function should directly preprocess the downloaded dataset.
22 | 
23 | By default, manually downloaded datasets should be located in the `manual` folder under a subdirectory with the dataset's name. Automatically downloaded datasets should have a copy saved to the `downloads` folder (to preserve reproducibility). All formatted datasets should be output to the `output` folder.
24 | 
25 | To generate just the datasets that are accessible but not licensed for sharing, use `pull_data.py --access`. To retrieve data from Twitter, get a valid [API key](https://developer.twitter.com/en/products/twitter-api/academic-research) and add it to the relevant scripts (`scrape_twitter_rumors.py` and `scrape_blm_countermovements.py`). You should then set the `status` field of these datasets in [`datasets.yaml`](../schema/datasets.yaml) to `accessible`.
26 | 
27 | # Automated Generation
28 | 
29 | Some tasks in the benchmark are generated automatically at scale from large datasets. [`generate.py`](generate.py) contains helper functions for doing this. Functionalities include:
30 | - For **categorical** features, contrasting each label with all other labels or creating every label-to-label pair.
31 | - For **discrete** features, pairing labels *step-wise* (i.e. 1 with 2, 2 with 3, etc.).
32 | 
33 | For example usage, reference the [`add_pairs.py`](add_pairs.py) script, which should also hold any of your automatically generated pairs; a minimal sketch is shown below. Datasets should be stored in tabular form in the `dataset` folder.
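
Below is a minimal sketch of how these helpers are invoked. The function names and keyword arguments mirror their usage in [`add_pairs.py`](add_pairs.py), while the file path, column names, and metadata values are illustrative:

```python
# Sketch: step-wise pairs over a discrete feature, following add_pairs.py.
import pandas as pd
from generate import generate_stepwise, make_distributions

df = pd.read_pickle("dataset/debate.pkl")  # tabular dataset with text + labels

# One distribution of text snippets per label value.
distributions = make_distributions(df, "abstract", "year")

# Adjacent labels are paired step-wise: 2013 vs. 2014, 2014 vs. 2015, etc.
pairs = generate_stepwise(
    df=df,
    label_col="year",
    name_template="debate_{label_col}_{label_pos}_{label_neg}",
    desc_template="were published in the year {keyword}",
    metadata={"dataset": "debate", "pair_type": "time", "flip": True},
)
```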
34 | 
35 | # Clusters
36 | 
37 | Several of the D5 tasks are clusters generated from large corpora. We store these unlabeled collections of text in the `unlabeled` folder. [`get_embeddings.py`](get_embeddings.py) embeds the text and [`create_cluster.py`](create_cluster.py) uses those embeddings to create K-means (or GMM) clusters.
38 | 
39 | # Assembly
40 | 
41 | [`make_benchmark.py`](make_benchmark.py) aggregates the full benchmark. If everything is set up properly, running the following command (from the main repo) should generate the full benchmark in the [`benchmarks`](../benchmarks) folder: `python scripts/make_benchmark.py --full`
42 | 
43 | # Utils
44 | - [`utils.py`](utils.py) contains helper functions that are used throughout.
45 | - [`test_discriminative.py`](test_discriminative.py) has weak classifiers which we use to measure problem difficulty.
--------------------------------------------------------------------------------
/scripts/add_pairs.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Code for automatic generation of pairs from datasets.
 3 | 
 4 | Author: Peter Zhang
 5 | """
 6 | 
 7 | from itertools import chain
 8 | from os.path import join
 9 | from typing import Dict, List
10 | 
11 | import pandas as pd
12 | from generate import *
13 | from parameters import *
14 | from tqdm import tqdm
15 | 
16 | 
17 | def add_debate() -> List[Dict]:
18 |     """
19 |     Returns a list of pairs for debate.
20 |     """
21 | 
22 |     df = pd.read_pickle(join(DATASET_FOLDER, "debate.pkl"))
23 | 
24 |     distributions = {}
25 |     text_col = "abstract"
26 | 
27 |     df = df[df[text_col].str.split().str.len() < SNIPPET_MAXLEN]  # keep snippets under the word limit
28 | 
29 |     for label_col in ("year", "arg_type", "argument", "debate_camp"):
30 |         distributions.update(make_distributions(df, text_col, label_col))
31 | 
32 |     pairs = []
33 | 
34 |     metadata = {
35 |         "dataset": "debate",
36 |         "generation": "which year the evidence was published",
37 |         "applications": [
38 |             {
39 |                 "target": "how debate topics have shifted over time",
40 |                 "user": "a coach reflecting on the debate community",
41 |                 "purely_exploratory": False,
42 |             }
43 |         ],
44 |         "pair_type": "time",
45 |         "example_hypotheses": [
46 |             "relies on identity politics",
47 |             "mentions the risk of extinction",
48 |         ],
49 |         "flip": True,
50 |     }
51 | 
52 |     pairs.extend(
53 |         generate_stepwise(
54 |             df=df,
55 |             label_col="year",
56 |             name_template="debate_{label_col}_{label_pos}_{label_neg}",
57 |             desc_template="were published in the year {keyword}",
58 |             metadata=metadata,
59 |         )
60 |     )
61 | 
62 |     argtype2kw = {
63 |         "k": "Kritiks",
64 |         "aff": "Affirmatives",
65 |         "case_neg": "Case Negatives",
66 |         "cp": "Counterplans",
67 |         "da": "Disadvantages",
68 |         "a2_k": "Kritik Answers",
69 |         "t": "Topicality arguments",
70 |         "th": "Theory arguments",
71 |         "ld": "Lincoln Douglas arguments",
72 |         "politics": "Politics",
73 |         "a2_cp": "Counterplan Answers",
74 |         "imp": "Impact Files",
75 |         "a2_da": "Disadvantage Answers",
76 |         "fw": "Framework arguments",
77 |     }
78 | 
79 |     metadata = {
80 |         "dataset": "debate",
81 |         "generation": "the category of argument",
82 |         "applications": [
83 |             {
84 |                 "target": "the general topics of each category",
85 |                 "user": "a novice to policy debate",
86 |                 "purely_exploratory": True,
87 |             }
88 |         ],
89 |         "pair_type": "subject",
90 |         "example_hypotheses": [
91 |             "brings up policy issues",
92 |             "critiques societal structures",
93 |         ],
94 |         "flip": True,
95 |     }
96 | 
97 |     pairs.extend(
generate_one_v_all( 99 | df, 100 | label_col="arg_type", 101 | name_template="debate_{label_col}_{label}_v_all", 102 | desc_template='are "{keyword}"', 103 | label2kw=argtype2kw.get, 104 | metadata=metadata, 105 | ) 106 | ) 107 | 108 | pairs.extend( 109 | generate_all_pairs( 110 | df, 111 | label_col="arg_type", 112 | name_template="debate_{label_col}_{label_pos}_{label_neg}", 113 | desc_template='are "{keyword}"', 114 | label2kw=argtype2kw.get, 115 | metadata=metadata, 116 | ) 117 | ) 118 | 119 | argument2kw = { 120 | # kritiks 121 | "ableism": "ableism", 122 | "anthro": "anthropocentrism", 123 | "afropess": "afropessimism", 124 | "antiblackness": "blackness", 125 | "baudrillard": "Baudrillard", 126 | "cap": "capitalism", 127 | "fem": "feminism", 128 | "foucault": "Foucault", 129 | "heidegger": "Heidegger", 130 | "militarism": "militarism", 131 | "neolib": "neoliberalism", 132 | "psycho": "psychoanalysis", 133 | "queerness": "queer pessimism", 134 | "security": "securitization", 135 | "settcol": "settler colonialism", 136 | # politics 137 | "midterms": "midterms", 138 | "elections": "elections", 139 | "politics": "politics", 140 | # counterplans 141 | "consult": "consultation", 142 | "states": "states taking action instead", 143 | "advantage_cp": "other ways to solve the problem", 144 | "courts": "courts taking action instead", 145 | } 146 | 147 | metadata = { 148 | "dataset": "debate", 149 | "generation": "the argument made", 150 | "applications": [ 151 | { 152 | "target": "the claims of each type of argument", 153 | "user": "a novice to policy debate", 154 | "purely_exploratory": False, 155 | } 156 | ], 157 | "pair_type": "subject", 158 | "example_hypotheses": [ 159 | "mentions solving policy problems", 160 | "critiques societal structures", 161 | ], 162 | "flip": True, 163 | } 164 | 165 | pairs.extend( 166 | generate_one_v_all( 167 | df, 168 | label_col="argument", 169 | desc_template='are arguments about "{keyword}"', 170 | name_template="debate_{label_col}_{label}_v_all", 171 | all_desc="are all other arguments", 172 | label2kw=argument2kw.get, 173 | metadata=metadata, 174 | ) 175 | ) 176 | 177 | pairs.extend( 178 | generate_all_pairs( 179 | df, 180 | label_col="argument", 181 | name_template="debate_{label_col}_{label_pos}_{label_neg}", 182 | desc_template='are arguments about "{keyword}"', 183 | label2kw=argument2kw.get, 184 | metadata=metadata, 185 | ) 186 | ) 187 | 188 | camp2kw = { 189 | "gdi": "Gonzaga (GDI)", 190 | "ddi": "Dartmouth DDIx", 191 | "nhsi": "Northwestern (NHSI)", 192 | "cdni": "Berkeley (CNDI)", 193 | "wyoming": "Wyoming", 194 | "gds": "Georgetown (GDS)", 195 | "utnif": "Texas (UTNIF)", 196 | "msdi": "Missouri State (MSDI)", 197 | "jdi": "Kansas (JDI)", 198 | "mich_7week": "Michigan (7-week)", 199 | "scdi": "Sun Country (SCDI)", 200 | "unt": "North Texas (UNT)", 201 | "samford": "Samford", 202 | "endi": "Emory (ENDI)", 203 | "hss": "Hoya-Spartan Scholars", 204 | "sdi": "Michigan State (SDI)", 205 | "mich_classic": "Michigan (Classic)", 206 | "mndi": "Michigan (MNDI)", 207 | "rks": "Wake Forest (RKS)", 208 | "georgia": "Georgia", 209 | "harvard": "Harvard", 210 | "wsdi": "Weber State (WSDI)", 211 | "utd": "UT Dallas (UTD)", 212 | "naudl": "NAUDL", 213 | "baylor": "Baylor", 214 | "mgc": "Mean Green Comet", 215 | "tdi": "The Debate Intensive", 216 | "nsd": "National Symposium for Debate", 217 | } 218 | 219 | metadata = { 220 | "dataset": "debate", 221 | "generation": "the debate camp that published the evidence", 222 | "applications": [ 223 | { 224 | "target": "what 
specific topics each debate camp focuses on", 225 | "user": "a debater deciding which camp to go to", 226 | "purely_exploratory": False, 227 | } 228 | ], 229 | "pair_type": "author", 230 | "example_hypotheses": [ 231 | "mentions solving policy problems", 232 | "critiques societal structures", 233 | ], 234 | "flip": True, 235 | } 236 | 237 | pairs.extend( 238 | generate_one_v_all( 239 | df, 240 | label_col="debate_camp", 241 | desc_template="are pieces of evidence compiled by {keyword}, a debate camp", 242 | all_desc="are pieces of evidence from every other debate camp", 243 | name_template="debate_{label_col}_{label}_v_all", 244 | label2kw=camp2kw.get, 245 | metadata=metadata, 246 | ) 247 | ) 248 | 249 | pairs.extend( 250 | generate_all_pairs( 251 | df, 252 | label_col="debate_camp", 253 | desc_template="are pieces of evidence compiled by {keyword}, a debate camp", 254 | label2kw=camp2kw.get, 255 | name_template="debate_{label_col}_{label_pos}_{label_neg}", 256 | metadata=metadata, 257 | ) 258 | ) 259 | 260 | for pair in tqdm(pairs): 261 | dists_pos = list(chain(*[distributions[c] for c in pair["pos_class"]])) 262 | dists_neg = list(chain(*[distributions[c] for c in pair["neg_class"]])) 263 | pair["pos_samples"] = dists_pos 264 | pair["neg_samples"] = dists_neg 265 | del pair["pos_class"] 266 | del pair["neg_class"] 267 | pair["hash"] = hash(tuple(dists_pos) + tuple(dists_neg)) 268 | 269 | return pairs 270 | 271 | 272 | def add_amazon_reviews() -> List[Dict]: 273 | """ 274 | Returns a list of pairs for Amazon reviews. 275 | """ 276 | 277 | df = pd.read_pickle(join(DATASET_FOLDER, "amazon_reviews.pkl")) 278 | 279 | distributions = {} 280 | text_col = "text" 281 | 282 | df = df[df[text_col].str.split().str.len() < SNIPPET_MAXLEN] 283 | 284 | for label_col in ("year", "product_category"): 285 | distributions.update(make_distributions(df, text_col, label_col)) 286 | 287 | distributions.update( 288 | make_distributions(df, text_col, label_cols=["product_category", "stars"]) 289 | ) 290 | 291 | pairs = [] 292 | 293 | metadata = { 294 | "dataset": "amazon_reviews", 295 | "generation": "how many stars the review gave", 296 | "applications": [ 297 | { 298 | "target": "which specific aspects users dislike, such as the price, features, or performance", 299 | "user": "a seller of various products on Amazon", 300 | "purely_exploratory": False, 301 | } 302 | ], 303 | "pair_type": "sentiment", 304 | "example_hypotheses": [ 305 | "mentions missing a critical piece", 306 | "complains about the lack of instructions", 307 | ], 308 | "flip": False, 309 | } 310 | 311 | category2kw = { 312 | "amazon_fashion": "fashion items", 313 | "beauty": "beauty products", 314 | "appliances": "appliances", 315 | "arts_crafts": "arts, crafts, and sewing products", 316 | "automotive": "automotive", 317 | "cds": "CDs", 318 | "cell_phones": "cell phones and accessories", 319 | "digital_music": "digital music", 320 | "gift_cards": "gift cards", 321 | "grocery": "grocery and gourmet food", 322 | "industrial_scientific": "industrial and scientific products", 323 | "luxury_beauty": "luxury beauty products", 324 | "magazines": "magazines", 325 | "music_instruments": "music instruments", 326 | "office": "office products", 327 | "patio": "patio products", 328 | "pantry": "pantry goods", 329 | "software": "software", 330 | "video_games": "video games", 331 | } 332 | 333 | desc_template = "are reviews of {prodkw} on Amazon giving {starkw} star" 334 | star_comp = [ 335 | ((1,), (5,)), 336 | ((1,), (3,)), 337 | ((2,), (4,)), 338 | ((4,), 
(5,)),
339 |     ]
340 | 
341 |     def stars2keyword(stars: List[int]):
342 |         assert all(s in [1, 2, 3, 4, 5] for s in stars)
343 |         if len(stars) == 1:
344 |             return str(stars[0])
345 |         stars = sorted(stars)
346 |         s = ""
347 |         while len(stars) > 1:
348 |             s += str(stars.pop(0)) + ", "  # stars are ints; cast before concatenating
349 |         return s + "or " + str(stars[0])
350 | 
351 |     for product_category in df["product_category"].unique():
352 |         prodkw = category2kw[product_category]
353 |         for pos_stars, neg_stars in star_comp:
354 |             pos_kw = stars2keyword(pos_stars)
355 |             neg_kw = stars2keyword(neg_stars)
356 |             pos_desc = desc_template.format(prodkw=prodkw, starkw=pos_kw)
357 |             neg_desc = desc_template.format(prodkw=prodkw, starkw=neg_kw)
358 |             pair_name = f'amazon_reviews_{product_category}_stars_{"".join(map(str, pos_stars))}_{"".join(map(str, neg_stars))}'
359 |             if pos_stars != (1,):
360 |                 pos_desc += "s"
361 |             if neg_stars != (1,):
362 |                 neg_desc += "s"
363 | 
364 |             metadata["dataset_description"] = f"Amazon reviews of {prodkw}"
365 |             pair = make_pair(
366 |                 pair_name=pair_name,
367 |                 label_col="stars",
368 |                 labels_pos=pos_stars,
369 |                 labels_neg=neg_stars,
370 |                 desc_pos=pos_desc,
371 |                 desc_neg=neg_desc,
372 |                 prefix=f"product_category_{product_category}_",
373 |                 metadata=metadata,
374 |             )
375 |             pairs.append(pair)
376 | 
377 |     for pair in tqdm(pairs):
378 |         dists_pos = list(chain(*[distributions[c] for c in pair["pos_class"]]))
379 |         dists_neg = list(chain(*[distributions[c] for c in pair["neg_class"]]))
380 |         pair["pos_samples"] = dists_pos
381 |         pair["neg_samples"] = dists_neg
382 |         pair["hash"] = hash(tuple(dists_pos) + tuple(dists_neg))
383 |         del pair["pos_class"]
384 |         del pair["neg_class"]
385 | 
386 |     return pairs
387 | 
--------------------------------------------------------------------------------
/scripts/build_benchmark.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | eval "$(conda shell.bash hook)"
3 | conda env create --file environment.yml -n opend5
4 | conda activate opend5
5 | python scripts/pull_data.py --access
6 | python scripts/make_benchmark.py --full
--------------------------------------------------------------------------------
/scripts/create_cluster.py:
--------------------------------------------------------------------------------
1 | """
2 | Code for creating clusters from embeddings.
3 | 
4 | Author: Peter Zhang, Ruiqi Zhong
5 | """
6 | 
7 | import glob
8 | import json
9 | import os
10 | from argparse import ArgumentParser
11 | from collections import defaultdict
12 | from datetime import datetime
13 | from typing import List, Tuple
14 | 
15 | import numpy as np
16 | import sklearn.cluster
17 | import sklearn.decomposition
18 | import sklearn.mixture
19 | import tqdm
20 | from sklearn.cluster import KMeans
21 | from sklearn.metrics.pairwise import euclidean_distances
22 | 
23 | 
24 | def load_data(embed_dir: str, subset_size: int) -> Tuple[np.ndarray, List[str]]:
25 |     """
26 |     Accepts a directory with embeddings and a sample size to use.
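    The directory is expected to hold the checkpoint files written by
    get_embeddings.py: paired <count>.npy / <count>.json files, where the
    integer prefix is the running count of snippets embedded so far.
    Returns the stacked embeddings and their texts, truncated to
    subset_size.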
27 | """ 28 | 29 | f_prefixes = sorted( 30 | [f.split(".")[0] for f in os.listdir(embed_dir) if f.endswith(".npy")], 31 | key=lambda x: int(x), 32 | ) 33 | all_embeddings, all_texts = [], [] 34 | for f in tqdm.tqdm(f_prefixes): 35 | new_embeddings = np.load(os.path.join(embed_dir, f + ".npy")) 36 | if len(new_embeddings.shape) == 2: 37 | all_embeddings.extend(new_embeddings) 38 | all_texts.extend(json.load(open(os.path.join(embed_dir, f + ".json")))) 39 | if len(all_embeddings) >= subset_size: 40 | break 41 | 42 | all_embeddings = np.array(all_embeddings)[:subset_size] 43 | all_texts = all_texts[:subset_size] 44 | 45 | return all_embeddings, all_texts 46 | 47 | 48 | def make_clusters(all_embeddings, first_pc, last_pc, cluster_method, k): 49 | 50 | # loading the embeddings and texts 51 | 52 | print(f"finished loading {len(all_embeddings)} embeddings") 53 | 54 | # first run PCA 55 | pca = sklearn.decomposition.PCA(n_components=1 + last_pc) 56 | 57 | # fit the PCA model to the embeddings 58 | all_embs = pca.fit_transform(all_embeddings) 59 | all_embs = all_embs[:, first_pc : last_pc + 1] 60 | print("finished PCA") 61 | 62 | # GMM clustering 63 | # defining the clustering model 64 | if cluster_method == "gmm": 65 | cluster = sklearn.mixture.GaussianMixture( 66 | n_components=k, covariance_type="full" 67 | ) 68 | elif cluster_method == "kmeans": 69 | cluster = KMeans(n_clusters=k) 70 | 71 | cluster.fit(all_embs) 72 | if cluster_method == "gmm": 73 | centers = cluster.means_ 74 | elif cluster_method == "kmeans": 75 | centers = cluster.cluster_centers_ 76 | 77 | print("finished clustering") 78 | cluster_idxes = cluster.predict(all_embs) 79 | 80 | print("finished predicting probabilities") 81 | center_pairwise_distances = euclidean_distances(centers, centers) 82 | 83 | return cluster_idxes, center_pairwise_distances 84 | 85 | 86 | def save_results(save_dir, cluster_idxes, all_texts, center_pairwise_distances): 87 | """ 88 | Save the results of the clustering. 
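    Writes two JSON files into save_dir (created if needed):
    clusters.json, mapping each cluster index to its member snippets, and
    l2_distance.json, mapping each cluster index to the list of Euclidean
    distances from its center to every other center.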
89 |     """
90 | 
91 |     # saving the results
92 |     if not os.path.exists(save_dir):
93 |         os.makedirs(save_dir)
94 |     clusters = defaultdict(list)
95 |     for cluster, text in zip(cluster_idxes, all_texts):
96 |         clusters[int(cluster)].append(text)
97 |     json.dump(clusters, open(os.path.join(save_dir, "clusters.json"), "w"))
98 |     l2_distances = dict(enumerate(map(list, center_pairwise_distances.astype(float))))
99 |     json.dump(l2_distances, open(os.path.join(save_dir, "l2_distance.json"), "w"))
100 | 
101 | 
102 | def main():
103 | 
104 |     parser = ArgumentParser()
105 |     parser.add_argument("--make_all", action="store_true")
106 |     parser.add_argument("--dataset", type=str)
107 |     parser.add_argument("--first_pc", type=int, default=1)
108 |     parser.add_argument("--last_pc", type=int, default=30)
109 |     parser.add_argument("--subset_size", type=int, default=100000)
110 |     parser.add_argument("--sqrt_size", action="store_true")
111 |     parser.add_argument("--k", type=int, default=128)
112 |     parser.add_argument("--cluster_method", type=str, default="kmeans")
113 | 
114 |     args = parser.parse_args()
115 |     make_all = args.make_all
116 |     dataset = args.dataset
117 |     first_pc = args.first_pc
118 |     last_pc = args.last_pc
119 |     subset_size = args.subset_size
120 |     sqrt_size = args.sqrt_size
121 |     k = args.k
122 |     cluster_method = args.cluster_method
123 | 
124 |     if make_all:
125 |         # strip the "results/" prefix and "_embeddings" suffix, since both
126 |         # are re-added when embed_dir is built below
127 |         datasets = [p[len("results/") : -len("_embeddings")] for p in glob.glob("results/*_embeddings")]
128 |     else:
129 |         datasets = [dataset]
130 | 
131 |     for dataset in datasets:
132 | 
133 |         embed_dir = f"results/{dataset}_embeddings"
134 | 
135 |         all_embeddings, all_texts = load_data(embed_dir, subset_size)
136 | 
137 |         if sqrt_size:
138 |             k = int(np.sqrt(len(all_embeddings)) / 2)
139 |             print(f"using sqrt size for dataset {dataset}, k={k}")
140 | 
141 |         cluster_idxes, center_pairwise_distances = make_clusters(
142 |             all_embeddings, first_pc, last_pc, cluster_method, k
143 |         )
144 | 
145 |         time = datetime.now().strftime("%Y%m%d_%H%M%S")
146 | 
147 |         if sqrt_size:
148 |             save_dir = f"results/{dataset}_{time}_clusters_sqrtsize"
149 |         else:
150 |             save_dir = f"results/{dataset}_{time}_clusters_{k}"
151 | 
152 |         save_results(
153 |             save_dir,
154 |             cluster_idxes,
155 |             all_texts,
156 |             center_pairwise_distances,
157 |         )
158 | 
159 | 
160 | if __name__ == "__main__":
161 |     main()
--------------------------------------------------------------------------------
/scripts/generate.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for generating distributions from datasets.
3 | 
4 | Author: Peter Zhang
5 | """
6 | 
7 | from itertools import combinations, product
8 | from typing import Callable, Dict, List
9 | 
10 | import pandas as pd
11 | 
12 | 
13 | def make_distributions(
14 |     df: pd.DataFrame,
15 |     text_col: str,
16 |     label_col: str = None,
17 |     label_cols: List[str] = None,
18 |     dropna: bool = True,
19 | ) -> Dict[str, List[str]]:
20 |     """
21 |     Accepts a DataFrame, text column, and label column
22 |     and creates a dictionary mapping each value of the
23 |     label to a list of strings from the text column.
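    A minimal sketch of the intended mapping (toy column names and values,
    not from the benchmark):

        >>> df = pd.DataFrame({"text": ["a", "b", "c"],
        ...                    "year": [2019, 2019, 2020]})
        >>> make_distributions(df, "text", "year")
        {'year_2019': ['a', 'b'], 'year_2020': ['c']}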
24 |     """
25 |     assert label_col or label_cols, "Must include one or more label columns"
26 | 
27 |     label_cols = label_cols or [label_col]
28 | 
29 |     distributions = {}
30 | 
31 |     all_labels = []
32 |     for col in label_cols:
33 |         labels = df[col]
34 |         if dropna:
35 |             labels = labels.dropna()
36 |         all_labels.append(labels.unique())
37 | 
38 |     for levels in product(*all_labels):
39 |         conditions = df[label_cols[0]] == levels[0]
40 |         dist_name = f"{label_cols[0]}_{levels[0]}"
41 |         for label_col, level in zip(label_cols[1:], levels[1:]):
42 |             conditions = conditions & (df[label_col] == level)
43 |             dist_name += f"_{label_col}_{level}"
44 | 
45 |         texts = df[conditions][text_col].tolist()
46 |         distributions[dist_name] = texts
47 | 
48 |     return distributions
49 | 
50 | 
51 | def make_pair(
52 |     pair_name: str,
53 |     label_col: str,
54 |     labels_pos: List[str],
55 |     labels_neg: List[str],
56 |     desc_pos: str,
57 |     desc_neg: str,
58 |     metadata: Dict = {},
59 |     prefix: str = "",
60 | ) -> Dict:
61 |     """
62 |     Accepts a pair name, a label column, lists of positive
63 |     and negative labels, and class descriptions, and
64 |     constructs a pair's metadata.
65 |     """
66 | 
67 |     pair = {
68 |         "pair": pair_name,
69 |         "pos_desc": desc_pos,
70 |         "pos_class": [f"{prefix}{label_col}_{label_pos}" for label_pos in labels_pos],
71 |         "neg_desc": desc_neg,
72 |         "neg_class": [f"{prefix}{label_col}_{label_neg}" for label_neg in labels_neg],
73 |     }
74 |     pair.update(metadata)
75 | 
76 |     return pair
77 | 
78 | 
79 | def generate_stepwise(
80 |     df: pd.DataFrame,
81 |     label_col: str,
82 |     desc_template: str,
83 |     name_template: str,
84 |     label2kw: Callable = lambda x: x,
85 |     metadata: Dict = {},
86 | ) -> List[Dict]:
87 |     """
88 |     Accepts a dataframe and an ordered label column and
89 |     returns a list of pair metadata comparing consecutive
90 |     labels in order.
91 |     """
92 | 
93 |     unique_labels = df[label_col].dropna().unique()
94 |     unique_labels = sorted(unique_labels)
95 | 
96 |     pairs = []
97 |     for label_pos, label_neg in zip(unique_labels[:-1], unique_labels[1:]):
98 |         kw_pos = label2kw(label_pos)
99 |         kw_neg = label2kw(label_neg)
100 |         pair_name = name_template.format(
101 |             label_col=label_col, label_pos=label_pos, label_neg=label_neg
102 |         )
103 |         desc_pos = desc_template.format(keyword=kw_pos)
104 |         desc_neg = desc_template.format(keyword=kw_neg)
105 |         pair = make_pair(
106 |             pair_name, label_col, [label_pos], [label_neg], desc_pos, desc_neg, metadata
107 |         )
108 |         pairs.append(pair)
109 | 
110 |     return pairs
111 | 
112 | 
113 | def generate_all_pairs(
114 |     df: pd.DataFrame,
115 |     label_col: str,
116 |     desc_template: str,
117 |     name_template: str,
118 |     label2kw: Callable = lambda x: x,
119 |     metadata: Dict = {},
120 | ) -> List[Dict]:
121 |     """
122 |     Accepts a dataframe and a category label column and
123 |     returns a list of pair metadata comparing all label pairs.
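    With labels {x, y, z} this yields the three comparisons (x, y),
    (x, z), and (y, z). A sketch with hypothetical arguments (df and the
    "camp" column are illustrative, not from the benchmark):

        >>> pairs = generate_all_pairs(
        ...     df, "camp",
        ...     desc_template="are snippets from {keyword}",
        ...     name_template="toy_{label_col}_{label_pos}_{label_neg}",
        ... )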
124 |     """
125 |     assert df[label_col].dtype == "category"
126 | 
127 |     pairs = []
128 |     labels = df[label_col].dropna().unique()
129 |     for label_pos, label_neg in combinations(labels, 2):
130 |         kw_pos = label2kw(label_pos)
131 |         kw_neg = label2kw(label_neg)
132 |         pair_name = name_template.format(
133 |             label_col=label_col, label_pos=label_pos, label_neg=label_neg
134 |         )
135 |         desc_pos = desc_template.format(keyword=kw_pos)
136 |         desc_neg = desc_template.format(keyword=kw_neg)
137 |         pair = make_pair(
138 |             pair_name, label_col, [label_pos], [label_neg], desc_pos, desc_neg, metadata
139 |         )
140 |         pairs.append(pair)
141 | 
142 |     return pairs
143 | 
144 | 
145 | def generate_one_v_all(
146 |     df: pd.DataFrame,
147 |     label_col: str,
148 |     name_template: str,
149 |     desc_template: str,
150 |     label2kw: Callable = lambda x: x,
151 |     metadata: Dict = {},
152 |     dropna: bool = True,
153 |     all_desc: str = None,
154 | ) -> List[Dict]:
155 |     """
156 |     Accepts a dataframe and a category label column and
157 |     returns a list of pair metadata comparing each label
158 |     against all other labels combined.
159 |     """
160 | 
161 |     pairs = []
162 |     labels = df[label_col]
163 |     if dropna:
164 |         labels = labels.dropna()  # drop NA on the Series; ndarrays lack dropna
165 |     labels = labels.unique().tolist()
166 | 
167 |     for label_pos in labels:
168 |         labels_neg = labels.copy()
169 |         labels_neg.remove(label_pos)
170 |         pair_name = name_template.format(label_col=label_col, label=label_pos)
171 |         kw_pos = label2kw(label_pos)
172 |         desc_pos = desc_template.format(keyword=kw_pos)
173 |         if all_desc:
174 |             desc_neg = all_desc
175 |         else:
176 |             desc_neg = desc_template.format(keyword=f"not {kw_pos}")
177 | 
178 |         pair = make_pair(
179 |             pair_name, label_col, [label_pos], labels_neg, desc_pos, desc_neg, metadata
180 |         )
181 |         pairs.append(pair)
182 | 
183 |     return pairs
--------------------------------------------------------------------------------
/scripts/get_embeddings.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for creating embeddings.
3 | 
4 | Author: Ruiqi Zhong
5 | """
6 | 
7 | import glob
8 | import json
9 | import os
10 | from argparse import ArgumentParser
11 | from functools import partial
12 | from os.path import join
13 | from typing import List
14 | 
15 | import numpy as np
16 | import torch
17 | import tqdm
18 | from transformers import (BertModel, BertTokenizer, RobertaModel,
19 |                           RobertaTokenizer, T5EncoderModel, T5Tokenizer)
20 | 
21 | device = "cuda" if torch.cuda.is_available() else "cpu"
22 | BSIZE = 32
23 | SAVE_EVERY = 10000
24 | DEFAULT_SAMPLES = 100000
25 | DATA_FOLDER = "unlabeled"
26 | 
27 | 
28 | def roberta_embed(model_tokenizer, sentences: List[str]):
29 |     """
30 |     Embeds a list of sentences using RoBERTa.
31 |     """
32 |     model, tokenizer = model_tokenizer
33 |     model.eval()
34 | 
35 |     with torch.no_grad():
36 |         inputs = tokenizer(
37 |             sentences, return_tensors="pt", padding=True, truncation=True
38 |         ).to(device)
39 |         outputs = model(**inputs).pooler_output
40 |     return outputs.cpu().numpy()
41 | 
42 | 
43 | def t5_embed(model_tokenizer, sentences: List[str]):
44 |     """
45 |     Embeds a list of sentences using T5.
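    Unlike the BERT/RoBERTa embedders above, the T5 encoder exposes no
    pooler output, so each snippet is embedded as the mean of its final
    hidden states. A usage sketch (assuming a t5-base checkpoint, whose
    hidden size is 768):

        >>> model = T5EncoderModel.from_pretrained("t5-base").to(device)
        >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
        >>> t5_embed((model, tokenizer), ["a snippet", "another"]).shape
        (2, 768)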
46 |     """
47 | 
48 |     model, tokenizer = model_tokenizer
49 |     model.eval()
50 | 
51 |     with torch.no_grad():
52 |         inputs = tokenizer(
53 |             sentences, return_tensors="pt", padding=True, truncation=True
54 |         ).to(device)
55 |         outputs = torch.mean(model(**inputs).last_hidden_state, dim=1)
56 |     return outputs.cpu().numpy()
57 | 
58 | 
59 | def bert_embed(model_tokenizer, sentences: List[str]):
60 |     """
61 |     Embeds a list of sentences using BERT.
62 |     """
63 | 
64 |     model, tokenizer = model_tokenizer
65 |     model.eval()
66 | 
67 |     with torch.no_grad():
68 |         inputs = tokenizer(
69 |             sentences, return_tensors="pt", padding=True, truncation=True
70 |         ).to(device)
71 |         outputs = model(**inputs).pooler_output
72 |     return outputs.cpu().numpy()
73 | 
74 | 
75 | def embed_sentences(
76 |     embed_func,
77 |     sentences: List[str],
78 |     samples: int,
79 |     bsize: int = BSIZE,
80 |     save_dir: str = None,
81 | ):
82 |     """
83 |     Embeds sentences in batches with embed_func, checkpointing to save_dir.
84 |     """
85 | 
86 |     embeddings, texts = [], []
87 |     save_threshold = [i * SAVE_EVERY for i in range(1, samples // SAVE_EVERY + 2)]
88 |     for i in tqdm.trange(0, len(sentences), bsize):
89 |         sentence_batch = sentences[i : i + bsize]
90 |         embeddings.extend(embed_func(sentence_batch))
91 |         texts.extend(sentence_batch)
92 |         finished_count = i + bsize
93 |         if save_dir is not None and save_threshold and finished_count > save_threshold[0]:
94 |             embeddings = np.array(embeddings)
95 |             np.save(os.path.join(save_dir, f"{finished_count}.npy"), embeddings)
96 |             json.dump(
97 |                 texts, open(os.path.join(save_dir, f"{finished_count}.json"), "w")
98 |             )
99 |             save_threshold.pop(0)
100 |             embeddings = []
101 |             texts = []
102 |     if save_dir is not None and len(embeddings) > 0:
103 |         np.save(
104 |             os.path.join(save_dir, f"{finished_count}.npy"),
105 |             np.array(embeddings),  # stack 1-D rows into a 2-D array, as above
106 |         )
107 |         json.dump(texts, open(os.path.join(save_dir, f"{finished_count}.json"), "w"))
108 | 
109 | 
110 | if __name__ == "__main__":
111 | 
112 |     parser = ArgumentParser()
113 |     parser.add_argument("--get_all", action="store_true")
114 |     parser.add_argument("--dataset", type=str)
115 |     parser.add_argument("--model_name", type=str, default="roberta-base")
116 |     parser.add_argument("--samples", type=int, default=DEFAULT_SAMPLES)
117 | 
118 |     args = parser.parse_args()
119 | 
120 |     get_all = args.get_all
121 |     model_name = args.model_name
122 |     samples = args.samples
123 |     dataset = args.dataset
124 | 
125 |     if "roberta" in model_name:
126 |         model = RobertaModel.from_pretrained(model_name).to(device)
127 |         tokenizer = RobertaTokenizer.from_pretrained(model_name)
128 |         model_tokenizer = (model, tokenizer)
129 |         embed_func = partial(roberta_embed, model_tokenizer)
130 |     elif "t5" in model_name:
131 |         model = T5EncoderModel.from_pretrained(model_name).to(device)
132 |         tokenizer = T5Tokenizer.from_pretrained(model_name)
133 |         model_tokenizer = (model, tokenizer)
134 |         embed_func = partial(t5_embed, model_tokenizer)
135 |     elif "bert" in model_name:
136 |         model = BertModel.from_pretrained(model_name).to(device)
137 |         tokenizer = BertTokenizer.from_pretrained(model_name)
138 |         model_tokenizer = (model, tokenizer)
139 |         embed_func = partial(bert_embed, model_tokenizer)
140 | 
141 |     if get_all:
142 |         files = glob.glob("unlabeled/*")
143 |         datasets = [file[len("unlabeled/") : -len(".json")] for file in files]
144 |     else:
145 |         datasets = [dataset]
146 | 
147 |     for dataset in datasets:
148 | 
149 |         print(f"embedding {dataset}")
150 |         save_dir = f"results/{dataset}_embeddings"
151 | 
152 |         os.makedirs(save_dir, exist_ok=True)
153 | 
154 |         filename = join(DATA_FOLDER,
f"{dataset}.json") 155 | data = json.load(open(filename, "r"))[:samples] 156 | 157 | embeddings = embed_sentences(embed_func, data, samples, save_dir=save_dir) 158 | -------------------------------------------------------------------------------- /scripts/make_benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constructs and cleans the benchmark from components. 3 | 4 | Author: Peter Zhang 5 | """ 6 | 7 | import argparse 8 | import json 9 | import pickle as pkl 10 | import random 11 | from collections import Counter, defaultdict 12 | from copy import deepcopy 13 | from itertools import chain, combinations 14 | from os.path import join 15 | from typing import List 16 | 17 | import numpy as np 18 | import yaml 19 | from add_pairs import * 20 | from parameters import * 21 | from test_discriminative import (discriminated_w_mnb, 22 | discriminated_w_simple_feature) 23 | from utils import remove_empty, split_truncate_word 24 | from yaml.loader import SafeLoader 25 | 26 | with open(join(SCHEMA_FOLDER, DATASET_SCHEMA), "r") as f: 27 | dataset_metadata = yaml.load(f, Loader=SafeLoader) 28 | 29 | 30 | def filter_pairs_for_desc( 31 | pairs: List[dict], 32 | threshold: int = DISC_THRESH, 33 | discriminator: str = "auc_roc_mnb", 34 | ) -> List[dict]: 35 | """ 36 | Removes pairs from a list that are too easily 37 | discriminated as judged from a threshold and 38 | discriminator. 39 | """ 40 | 41 | filtered_pairs = [] 42 | 43 | for pair in pairs: 44 | disc = pair[discriminator] 45 | if disc < threshold: 46 | pairs.append(construct_pair(**pair)) 47 | 48 | return filtered_pairs 49 | 50 | 51 | def construct_pair( 52 | include_stats: bool = True, 53 | include_roc: bool = True, 54 | truncate: bool = True, 55 | **pair_data, 56 | ) -> dict: 57 | """ 58 | Constructs a pair from pair data and optionally adds metadata. 
59 | """ 60 | missing_fields = [field for field in PAIR_FIELDS if field not in pair_data.keys()] 61 | # check for all required fields 62 | assert not missing_fields, f'Fields ({", ".join(missing_fields)}) missing' 63 | 64 | for i, application in enumerate(pair_data["applications"]): 65 | missing_fields = [ 66 | field for field in APPLICATION_FIELDS if field not in application.keys() 67 | ] 68 | assert ( 69 | not missing_fields 70 | ), f'Fields ({", ".join(missing_fields)}) missing from application {i}' 71 | 72 | dataset = pair_data["dataset"] 73 | pos_samples = remove_empty(pair_data["pos_samples"]) 74 | neg_samples = remove_empty(pair_data["neg_samples"]) 75 | 76 | if truncate: 77 | pos_weighted_samples = split_truncate_word(pos_samples) 78 | neg_weighted_samples = split_truncate_word(neg_samples) 79 | 80 | pos_weights = [s[1] for s in pos_weighted_samples] 81 | neg_weights = [s[1] for s in neg_weighted_samples] 82 | pos_samples = [s[0] for s in pos_weighted_samples] 83 | neg_samples = [s[0] for s in neg_weighted_samples] 84 | 85 | pair_data["pos_weights"] = pos_weights 86 | pair_data["neg_weights"] = neg_weights 87 | pair_data["pos_samples"] = pos_samples 88 | pair_data["neg_samples"] = neg_samples 89 | else: 90 | pair_data["pos_weights"] = [1] * len(pos_samples) 91 | pair_data["neg_weights"] = [1] * len(neg_samples) 92 | 93 | pair_data["discipline"] = dataset_metadata[dataset]["discipline"] 94 | if not ("dataset_description" in pair_data and pair_data["dataset_description"]): 95 | pair_data["dataset_description"] = dataset_metadata[dataset]["description"] 96 | else: 97 | desc = pair_data["dataset_description"].strip() 98 | pair_data["dataset_description"] = desc[0].lower() + desc[1:] 99 | pair_data["status"] = dataset_metadata[dataset]["status"] 100 | pair_data["expertise"] = dataset_metadata[dataset]["expertise"] 101 | pair_data["hash"] = hash(tuple(pos_samples + neg_samples)) 102 | 103 | if include_stats: 104 | pair_data["avg_chars"] = float( 105 | np.mean(list(map(len, pos_samples + neg_samples))) 106 | ) 107 | pair_data["avg_words"] = float( 108 | np.mean(list(map(lambda s: len(s.split(" ")), pos_samples + neg_samples))) 109 | ) 110 | 111 | if include_roc: 112 | auc_roc_simple = discriminated_w_simple_feature(pos_samples, neg_samples) 113 | auc_roc_mnb = discriminated_w_mnb(pos_samples, neg_samples) 114 | pair_data["auc_roc_simple"] = auc_roc_simple 115 | pair_data["auc_roc_mnb"] = auc_roc_mnb 116 | 117 | return pair_data 118 | 119 | 120 | def generate_pair(**pair_data) -> dict: 121 | """ 122 | Creates pair of distributions and relevant metadata. 123 | """ 124 | 125 | dataset = pair_data["dataset"] 126 | pos_class = pair_data["pos_class"] 127 | neg_class = pair_data["neg_class"] 128 | del pair_data["pos_class"] 129 | del pair_data["neg_class"] 130 | 131 | path = f"{OUTPUT_FOLDER}/{dataset}.json" 132 | output = json.load(open(path, "r")) 133 | pos_dists = [output["data"][dist] for dist in pos_class] 134 | neg_dists = [output["data"][dist] for dist in neg_class] 135 | pos_samples = list(chain.from_iterable(pos_dists)) 136 | neg_samples = list(chain.from_iterable(neg_dists)) 137 | 138 | return construct_pair(pos_samples=pos_samples, neg_samples=neg_samples, **pair_data) 139 | 140 | 141 | def add_spurious_data(path: str) -> List[dict]: 142 | """ 143 | Adds externally generated spurious pairs. 
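    The file at path is assumed to be a JSON list of pair dicts, each
    already carrying the fields construct_pair requires (PAIR_FIELDS plus
    pos_samples/neg_samples and applications).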
144 | """ 145 | 146 | data = json.load(open(path)) 147 | 148 | pairs = [] 149 | 150 | for pair in data: 151 | pairs.append(construct_pair(**pair)) 152 | 153 | return pairs 154 | 155 | 156 | def add_cluster_data( 157 | sample_path: str, 158 | distance_path: str, 159 | dataset: str, 160 | ) -> List[dict]: 161 | """ 162 | Adds externally generated cluster data. 163 | """ 164 | 165 | pairs = [] 166 | generation = 'which automatically generated "cluster" the snippet is from' 167 | user = "a data scientist performing unsupervised clustering" 168 | target = "what each cluster represents" 169 | purely_exploratory = True 170 | 171 | applications = [ 172 | { 173 | "user": user, 174 | "target": target, 175 | "purely_exploratory": purely_exploratory, 176 | } 177 | ] 178 | 179 | sample_data = json.load(open(sample_path, "r")) 180 | distance_data = json.load(open(distance_path, "r")) 181 | 182 | print("1 v all") 183 | 184 | # 1 v all framing 185 | flip = False 186 | 187 | for pos_id, pos_samples in tqdm(sample_data.items()): 188 | 189 | pair_type = "cluster_versus_all" 190 | 191 | neg_ids = [id for id in sample_data if id != pos_id] 192 | neg_samples = list(chain(*[sample_data[id] for id in neg_ids])) 193 | 194 | pair = f"{dataset}_cluster_{pos_id}_v_all" 195 | pos_desc = "are from a particular cluster" 196 | neg_desc = "are from the rest of the cluster" 197 | 198 | random.seed(0) 199 | neg_samples = random.sample(neg_samples, k=len(pos_samples)) 200 | 201 | p = construct_pair( 202 | pair=pair, 203 | dataset=dataset, 204 | generation=generation, 205 | applications=applications, 206 | example_hypotheses=[], 207 | pos_desc=pos_desc, 208 | neg_desc=neg_desc, 209 | pos_samples=pos_samples, 210 | neg_samples=neg_samples, 211 | pair_type=pair_type, 212 | flip=flip, 213 | note="", 214 | include_roc=True, 215 | ) 216 | 217 | pairs.append(p) 218 | 219 | clusters = sample_data.keys() # pairs of clusters, 3% closest distance 220 | 221 | distances = list(chain(*distance_data.values())) 222 | distances = [d for d in distances if d] # remove 0 223 | 224 | def get_cutoff(p): 225 | return np.percentile(distances, p) # returns the percentile distance 226 | 227 | CLOSEST = 3 228 | CLOSE = 8 229 | SOMEWHAT_CLOSE = 15 230 | 231 | def get_bucket(d): 232 | """Decides how close a given cluster is.""" 233 | if d < get_cutoff(CLOSEST): 234 | return "very close" 235 | if d < get_cutoff(CLOSE): 236 | return "close" 237 | if d < get_cutoff(SOMEWHAT_CLOSE): 238 | return "somewhat close" 239 | return None 240 | 241 | # 1v1 by closeness 242 | flip = True 243 | 244 | bucketed_combos = defaultdict(list) 245 | for cluster1, cluster2 in combinations(clusters, 2): 246 | distance = distance_data[cluster1][int(cluster2)] 247 | bucketed_combos[get_bucket(distance)].append((cluster1, cluster2)) 248 | 249 | SAMPLE_SIZE = 20 250 | 251 | print("closeness") 252 | for closeness in ("very close", "close", "somewhat close"): 253 | 254 | np.random.seed(1) 255 | np.random.shuffle(bucketed_combos[closeness]) 256 | samples = bucketed_combos[closeness][:SAMPLE_SIZE] 257 | 258 | for cluster1, cluster2 in tqdm(samples): 259 | 260 | pair_type = f'{closeness.replace(" ","_")}_clusters' 261 | 262 | pos_samples = sample_data[cluster1] 263 | neg_samples = sample_data[cluster2] 264 | 265 | pair = f"{dataset}_cluster_{cluster1}_v_{cluster2}" 266 | 267 | pos_desc = "are from one cluster" 268 | neg_desc = f"are from a {closeness} cluster" 269 | 270 | p = construct_pair( 271 | pair=pair, 272 | dataset=dataset, 273 | generation=generation, 274 | 
applications=applications, 275 | example_hypotheses=[], 276 | pos_desc=pos_desc, 277 | neg_desc=neg_desc, 278 | pos_samples=pos_samples, 279 | neg_samples=neg_samples, 280 | pair_type=pair_type, 281 | flip=flip, 282 | note="", 283 | include_roc=True, 284 | ) 285 | 286 | pairs.append(p) 287 | 288 | return pairs 289 | 290 | 291 | def post_processing(benchmark, max_size=55, train_size=0.5, remove_private=True): 292 | """ 293 | Flatten applications, resample overrepresented samples, 294 | and create train-test splits. 295 | """ 296 | 297 | random.seed(0) 298 | 299 | # flatten applications 300 | flattened_benchmark = [] 301 | for pair in benchmark: 302 | if not "status" in pair: 303 | print(pair["pair"]) 304 | if remove_private and (pair["status"] == "private"): 305 | continue 306 | for application in pair["applications"]: 307 | new_pair = deepcopy(pair) 308 | del new_pair["applications"] 309 | new_pair["application"] = application 310 | flattened_benchmark.append(new_pair) 311 | benchmark = flattened_benchmark 312 | 313 | # downsample 314 | by_dataset = defaultdict(list) 315 | for app in benchmark: 316 | by_dataset[app["dataset"]].append(app["hash"]) 317 | keep_hashes = set() 318 | for dataset in by_dataset: 319 | hashes = by_dataset[dataset] 320 | random.shuffle(hashes) 321 | keep_hashes.update(hashes[:max_size]) 322 | 323 | benchmark = [p for p in benchmark if p["hash"] in keep_hashes] 324 | print(Counter(p["dataset"] for p in benchmark)) 325 | 326 | # train-test split 327 | for application in benchmark: 328 | pos_samples, neg_samples = ( 329 | application["pos_samples"], 330 | application["neg_samples"], 331 | ) 332 | 333 | pos_samples = deepcopy(pos_samples) 334 | neg_samples = deepcopy(neg_samples) 335 | 336 | random.shuffle(pos_samples) 337 | random.shuffle(neg_samples) 338 | 339 | train_pos_samples = pos_samples[: int(len(pos_samples) * train_size)] 340 | train_neg_samples = neg_samples[: int(len(neg_samples) * train_size)] 341 | 342 | test_pos_samples = pos_samples[int(len(pos_samples) * train_size) :] 343 | test_neg_samples = neg_samples[int(len(neg_samples) * train_size) :] 344 | 345 | application["split"] = { 346 | "train": { 347 | "pos_samples": train_pos_samples, 348 | "neg_samples": train_neg_samples, 349 | }, 350 | "test": {"pos_samples": test_pos_samples, "neg_samples": test_neg_samples}, 351 | } 352 | 353 | benchmark = [reformat(problem) for problem in benchmark] 354 | 355 | return benchmark 356 | 357 | 358 | def reformat(example_problem): 359 | """ 360 | Final formatting for paper. 
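    Renames the internal pos/neg fields to the A/B naming used in the
    paper (pos_desc -> A_desc, neg_desc -> B_desc), flattens the single
    application into the problem, renames the train/test split to
    research/validation, and drops bookkeeping keys such as hash and
    pair_type.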
361 | """ 362 | 363 | example_problem = deepcopy(example_problem) 364 | example_problem["A_desc"] = example_problem["pos_desc"] 365 | example_problem["B_desc"] = example_problem["neg_desc"] 366 | example_problem.update(example_problem["application"]) 367 | 368 | del example_problem["pos_desc"] 369 | del example_problem["neg_desc"] 370 | del example_problem["pos_samples"] 371 | del example_problem["neg_samples"] 372 | del example_problem["pos_weights"] 373 | del example_problem["neg_weights"] 374 | del example_problem["application"] 375 | 376 | example_problem["split"] = { 377 | "research": { 378 | "A_samples": example_problem["split"]["train"]["pos_samples"], 379 | "B_samples": example_problem["split"]["train"]["neg_samples"], 380 | }, 381 | "validation": { 382 | "A_samples": example_problem["split"]["test"]["pos_samples"], 383 | "B_samples": example_problem["split"]["test"]["neg_samples"], 384 | }, 385 | } 386 | example_problem["dataset_abbreviation"] = example_problem["dataset"] 387 | if "example_hypotheses" not in example_problem: 388 | example_problem["example_hypotheses"] = [] 389 | 390 | for key in [ 391 | "dataset", 392 | "note", 393 | "hash", 394 | "avg_chars", 395 | "avg_words", 396 | "auc_roc_simple", 397 | "auc_roc_mnb", 398 | "pair_id", 399 | "application_idx_in_pair", 400 | "hypotheses_is_native", 401 | "v1_id", 402 | "v2_id", 403 | "pair_type", 404 | "pair", 405 | ]: 406 | if key in example_problem: 407 | del example_problem[key] 408 | 409 | return example_problem 410 | 411 | 412 | def describe_pair(pair): 413 | """Prints relevant information about dataset""" 414 | 415 | print("----") 416 | print(f'pair: {pair["pair"]}') 417 | print(f'# pos samples: {len(pair["pos_samples"])}') 418 | print(f'# neg samples: {len(pair["neg_samples"])}') 419 | print(f'avg chars: {pair["avg_chars"]:.2f}') 420 | print(f'avg words: {pair["avg_words"]:.2f}') 421 | if "auc_roc_simple" in pair and "auc_roc_mnb" in pair: 422 | print(f'simple auc: {pair["auc_roc_simple"]:.3f}') 423 | print(f'mnb auc: {pair["auc_roc_mnb"]:.3f}') 424 | 425 | 426 | def export_schema(pairs): 427 | pairs = deepcopy(pairs) 428 | for p in pairs: 429 | del p["pos_samples"] 430 | del p["neg_samples"] 431 | with open("schema/pairs.yaml", "w") as f: 432 | f.write(yaml.dump(pairs)) 433 | 434 | 435 | def main(): 436 | 437 | parser = argparse.ArgumentParser( 438 | prog="Make Pairs", description="Creates the pairs comprising the D3 benchmark." 
439 | ) 440 | 441 | parser.add_argument("--full", action="store_true") 442 | parser.add_argument("--add_errors", action="store_true") 443 | parser.add_argument("--add_clusters", action="store_true") 444 | parser.add_argument("--add_datasets", action="store_true") 445 | parser.add_argument("--add_pairs", action="store_true") 446 | parser.add_argument("--add_spurious", action="store_true") 447 | parser.add_argument("--access", action="store_true") 448 | parser.add_argument("--test_dataset") 449 | 450 | args = parser.parse_args() 451 | if args.full: 452 | args.add_errors = True 453 | args.add_clusters = True 454 | args.add_datasets = True 455 | args.add_pairs = True 456 | args.add_spurious = True 457 | 458 | with open(join(SCHEMA_FOLDER, PAIRS_SCHEMA)) as f: 459 | pair_schema = yaml.load(f, Loader=SafeLoader) 460 | 461 | all_pairs = [] 462 | 463 | if args.test_dataset: 464 | dataset = args.test_dataset 465 | generations = pair_schema[dataset] 466 | for gen_dict in generations: 467 | if "note" not in gen_dict: 468 | gen_dict["note"] = "" 469 | pairs = gen_dict["pairs"] 470 | del gen_dict["pairs"] 471 | for pair_name, pair_info in pairs.items(): 472 | new_pair = generate_pair( 473 | dataset=dataset, 474 | pair=f"{dataset}_{pair_name}", 475 | **gen_dict, 476 | **pair_info, 477 | ) 478 | describe_pair(new_pair) 479 | all_pairs.append(new_pair) 480 | 481 | if args.add_errors: 482 | 483 | print("errors") 484 | task_errors_file = join(COMPONENTS_FOLDER, "error_analysis_1203.json") 485 | task_error_pairs = json.load(open(task_errors_file)) 486 | for pair in task_error_pairs: 487 | all_pairs.append(construct_pair(**pair)) 488 | 489 | if args.add_spurious: 490 | 491 | print("spurious") 492 | spurious_file = join(COMPONENTS_FOLDER, "spurious_0110.json") 493 | 494 | all_pairs.extend(add_spurious_data(spurious_file)) 495 | 496 | if args.add_clusters: 497 | 498 | print("clusters") 499 | 500 | for dataset in CLUSTER_DATASETS: 501 | print(dataset) 502 | sample_path = f"clusters/{dataset}/clusters/clusters.json" 503 | distance_path = f"clusters/{dataset}/clusters/l2_distance.json" 504 | 505 | cluster_pairs = add_cluster_data( 506 | sample_path=sample_path, 507 | distance_path=distance_path, 508 | dataset=dataset, 509 | ) 510 | 511 | all_pairs.extend(cluster_pairs) 512 | 513 | if args.add_datasets: 514 | 515 | print("datasets") 516 | dataset_constructors = [add_debate, add_amazon_reviews] 517 | 518 | for con in dataset_constructors: 519 | print("loading") 520 | pairs = con() 521 | print("processing") 522 | for pair in tqdm(pairs): 523 | all_pairs.append(construct_pair(**pair, truncate=False)) 524 | 525 | if args.add_pairs or args.access: 526 | 527 | print("pairs") 528 | 529 | if args.access: 530 | accessible_datasets = [ 531 | d 532 | for d, metadata in dataset_metadata.items() 533 | if metadata["status"] == "accessible" 534 | ] 535 | pair_schema = { 536 | dataset: generation 537 | for dataset, generation in pair_schema.items() 538 | if dataset in accessible_datasets 539 | } 540 | 541 | for dataset, generations in pair_schema.items(): 542 | for gen_dict in generations: 543 | if "note" not in gen_dict: 544 | gen_dict["note"] = "" 545 | pairs = gen_dict["pairs"] 546 | del gen_dict["pairs"] 547 | for pair_name, pair_info in pairs.items(): 548 | new_pair = generate_pair( 549 | dataset=dataset, 550 | pair=f"{dataset}_{pair_name}", 551 | **gen_dict, 552 | **pair_info, 553 | ) 554 | describe_pair(new_pair) 555 | all_pairs.append(new_pair) 556 | 557 | all_pairs = post_processing(all_pairs, remove_private=False) 558 | 
559 |     pkl.dump(all_pairs, open(BENCHMARK_FILE, "wb"))
560 | 
561 | 
562 | if __name__ == "__main__":
563 |     main()
--------------------------------------------------------------------------------
/scripts/parameters.py:
--------------------------------------------------------------------------------
1 | """
2 | Relevant parameters for constructing the benchmark.
3 | """
4 | 
5 | from datetime import datetime
6 | from os.path import join
7 | 
8 | # folder names
9 | DOWNLOAD_FOLDER = "downloads"
10 | DATASET_FOLDER = "datasets"
11 | MANUAL_FOLDER = "manual"
12 | OUTPUT_FOLDER = "output"
13 | UNLABELED_FOLDER = "unlabeled"
14 | SCHEMA_FOLDER = "schema"
15 | PAIRS_SCHEMA = "pairs.yaml"
16 | DATASET_SCHEMA = "datasets.yaml"
17 | BENCHMARK_FOLDER = "benchmarks"
18 | COMPONENTS_FOLDER = "components"
19 | BENCHMARK_NAME = f'benchmark_{datetime.now().strftime("%m%d")}.pkl'
20 | BENCHMARK_FILE = join(BENCHMARK_FOLDER, BENCHMARK_NAME)
21 | 
22 | SNIPPET_MAXLEN = 256
23 | 
24 | CLUSTER_DATASETS = ["all_the_news", "wikitext", "debate", "poetry"]
25 | 
26 | # pair structure
27 | PAIR_FIELDS = {
28 |     "dataset",
29 |     "generation",
30 |     "applications",
31 |     "pair_type",
32 |     "pair",
33 |     "pos_samples",
34 |     "neg_samples",
35 |     "pos_desc",
36 |     "neg_desc",
37 |     "flip",
38 | }
39 | 
40 | APPLICATION_FIELDS = {"target", "user", "purely_exploratory"}
41 | 
42 | # parameters
43 | DISC_THRESH = 0.58
--------------------------------------------------------------------------------
/scripts/scrape_admin_statements.py:
--------------------------------------------------------------------------------
1 | """
2 | Constructs the admin_statements distributions.
3 | """
4 | 
5 | import glob
6 | 
7 | import pdfplumber
8 | from parameters import *
9 | from utils import *
10 | 
11 | 
12 | def scrape():
13 |     """
14 |     Scrapes the statements of administration policy from the
15 |     statements-of-administration-policy-main repository.
16 |     """
17 | 
18 |     NAME = "admin_statements"
19 | 
20 |     directory = f"{DOWNLOAD_FOLDER}/{NAME}"
21 | 
22 |     administrations = ["44-Obama", "45-Trump", "46-Biden"]
23 | 
24 |     data = {}
25 | 
26 |     for admin in administrations:
27 |         print(admin)
28 | 
29 |         files = glob.glob(
30 |             f"{directory}/statements-of-administration-policy-main/archive/statements/{admin}/**/*.pdf"
31 |         )
32 | 
33 |         statements = []
34 | 
35 |         for file in files:
36 |             text = ""
37 |             try:
38 |                 with pdfplumber.open(file) as pdf:
39 |                     for page in pdf.pages:
40 |                         text += page.extract_text() + " "
41 | 
42 |                 if (loc := text.find("The Administration")) != -1:  # str.find returns -1 when absent
43 |                     text = text[loc:].replace("\n", "")
44 |                     text = text.replace("*", "").strip()
45 |                     texts = split_delimiter_(text, "\n")
46 |                     texts = split_truncate(texts)
47 |                     statements.extend(texts)
48 |             except Exception:
49 |                 pass
50 | 
51 |         data[admin] = statements
52 | 
53 |     return data
--------------------------------------------------------------------------------
/scripts/scrape_blm_countermovements.py:
--------------------------------------------------------------------------------
1 | """
2 | Constructs the blm_countermovements distributions.
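Requires Twitter API credentials: fill in CONSUMER_KEY, CONSUMER_SECRET,
OAUTH_TOKEN, and OAUTH_TOKEN_SECRET below before running; they default to
None, so unauthenticated calls will fail.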
3 | """
4 | 
5 | import pandas as pd
6 | import tweepy
7 | from parameters import *
8 | from utils import encode_ascii
9 | 
10 | CONSUMER_KEY = None
11 | CONSUMER_SECRET = None
12 | OAUTH_TOKEN = None
13 | OAUTH_TOKEN_SECRET = None
14 | 
15 | auth = tweepy.OAuthHandler(CONSUMER_KEY, CONSUMER_SECRET)
16 | auth.set_access_token(OAUTH_TOKEN, OAUTH_TOKEN_SECRET)
17 | api = tweepy.API(auth, wait_on_rate_limit=True)
18 | 
19 | 
20 | def get_tweet_info(row):
21 |     """
22 |     Calls Twitter API to get tweet text.
23 |     """
24 | 
25 |     id_of_tweet = int(row["ID"])
26 |     try:
27 |         tweet = api.get_status(id_of_tweet)
28 |         row["text"] = tweet.text
29 |     except Exception:
30 |         row["text"] = None
31 |     return row
32 | 
33 | 
34 | def scrape():
35 |     """
36 |     Scrapes BLM countermovement Tweets.
37 |     """
38 | 
39 |     paths = {
40 |         "all_lives_matter": f"{MANUAL_FOLDER}/blm_countermovements/AllLivesMatter_IDs.csv",
41 |         "blue_lives_matter": f"{MANUAL_FOLDER}/blm_countermovements/BlueLivesMatter_IDs.csv",
42 |         "white_lives_matter": f"{MANUAL_FOLDER}/blm_countermovements/WhiteLivesMatter_IDs.csv",
43 |     }
44 | 
45 |     df = pd.DataFrame()
46 |     for movement, filepath in paths.items():
47 |         with open(filepath, "r") as f:
48 |             IDs = f.readlines()[1:]
49 |         movement_df = pd.DataFrame({"movement": movement, "ID": IDs}).sample(
50 |             n=1000, replace=False, random_state=0
51 |         )
52 |         df = pd.concat([df, movement_df])  # DataFrame.append is deprecated
53 | 
54 |     df = df.apply(get_tweet_info, axis=1)
55 |     df = df.dropna(axis=0)
56 | 
57 |     data = {}
58 |     for movement in paths:
59 |         data[movement] = df[df.movement == movement].text.apply(encode_ascii).tolist()
60 | 
61 |     return data
--------------------------------------------------------------------------------
/scripts/scrape_open_review.py:
--------------------------------------------------------------------------------
1 | """
2 | Constructs the open_review distributions.
3 | """
4 | 
5 | from collections import defaultdict
6 | 
7 | import numpy as np
8 | import openreview
9 | import pandas as pd
10 | 
11 | client = openreview.Client(baseurl="https://api.openreview.net")
12 | 
13 | 
14 | def scrape():
15 |     """
16 |     Uses the OpenReview API to scrape ICLR reviews.
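    For each ICLR 2018-2021 submission, collects the abstract and its
    review ratings, then buckets papers by average rating: >= 7 ("great"),
    in [5, 7) ("good"), and < 5 ("bad"), returning the three abstract
    lists in that order. Ratings are reported as strings such as
    "7: Good paper, accept", so the leading integer is parsed off before
    averaging.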
17 |     """
18 | 
19 |     invites = [
20 |         (
21 |             2018,
22 |             "ICLR.cc/2018/Conference/-/Blind_Submission",
23 |             "ICLR.cc/2018/Conference/-/Paper.*/Official_Review",
24 |         ),
25 |         (
26 |             2019,
27 |             "ICLR.cc/2019/Conference/-/Blind_Submission",
28 |             "ICLR.cc/2019/Conference/-/Paper.*/Official_Review",
29 |         ),
30 |         (
31 |             2020,
32 |             "ICLR.cc/2020/Conference/-/Blind_Submission",
33 |             "ICLR.cc/2020/Conference/Paper.*/-/Official_Review",
34 |         ),
35 |         (
36 |             2021,
37 |             "ICLR.cc/2021/Conference/-/Blind_Submission",
38 |             "ICLR.cc/2021/Conference/Paper.*/-/Official_Review",
39 |         ),
40 |     ]
41 | 
42 |     metadata = []
43 | 
44 |     for year, submission_invite, review_invite in invites:
45 | 
46 |         submissions = openreview.tools.iterget_notes(
47 |             client, invitation=submission_invite
48 |         )
49 |         submissions_by_forum = {n.forum: n for n in submissions}
50 | 
51 |         reviews = openreview.tools.iterget_notes(client, invitation=review_invite)
52 |         reviews_by_forum = defaultdict(list)
53 |         for review in reviews:
54 |             reviews_by_forum[review.forum].append(review)
55 | 
56 |         for forum in submissions_by_forum:
57 | 
58 |             forum_reviews = reviews_by_forum[forum]
59 |             review_ratings = [int(n.content["rating"].split(":")[0]) for n in forum_reviews]
60 |             average_rating = np.mean(review_ratings)
61 | 
62 |             submission_content = submissions_by_forum[forum].content
63 |             abstract = submission_content["abstract"]
64 | 
65 |             forum_metadata = {
66 |                 "forum": forum,
67 |                 "review_ratings": review_ratings,
68 |                 "average_rating": average_rating,
69 |                 "abstract": abstract,
70 |                 "year": year,
71 |             }
72 | 
73 |             metadata.append(forum_metadata)
74 | 
75 |     df = pd.DataFrame(metadata)
76 |     great_papers = df[df.average_rating >= 7].abstract.tolist()
77 |     good_papers = df[
78 |         (df.average_rating >= 5) & (df.average_rating < 7)
79 |     ].abstract.tolist()
80 |     bad_papers = df[df.average_rating < 5].abstract.tolist()
81 | 
82 |     return great_papers, good_papers, bad_papers
--------------------------------------------------------------------------------
/scripts/scrape_twitter_rumors.py:
--------------------------------------------------------------------------------
1 | """
2 | Constructs the twitter_rumors distributions.
3 | """
4 | 
5 | import pandas as pd
6 | import tweepy
7 | from utils import encode_ascii
8 | 
9 | CONSUMER_KEY = None
10 | CONSUMER_SECRET = None
11 | OAUTH_TOKEN = None
12 | OAUTH_TOKEN_SECRET = None
13 | 
14 | auth = tweepy.OAuthHandler(CONSUMER_KEY, CONSUMER_SECRET)
15 | auth.set_access_token(OAUTH_TOKEN, OAUTH_TOKEN_SECRET)
16 | api = tweepy.API(auth, wait_on_rate_limit=True)
17 | 
18 | 
19 | def get_tweet_info(row):
20 |     """
21 |     Calls Twitter API to get tweet text.
22 |     """
23 | 
24 |     id_of_tweet = int(row["Tweet ID"])
25 |     try:
26 |         tweet = api.get_status(id_of_tweet)
27 |         row["text"] = tweet.text
28 |         row["dt"] = tweet.created_at
29 |     except Exception:
30 |         row["text"] = None
31 |         row["dt"] = None
32 |     return row
33 | 
34 | 
35 | STAGES = ["early", "mid", "late"]
36 | 
37 | 
38 | def scrape():
39 |     """
40 |     Downloads Tweet IDs from Zenodo and scrapes Tweets from the Twitter API.
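    Each rumor's rehydrated tweets are split into early/mid/late thirds
    by posting time with pd.qcut, so every {rumor}_{stage} distribution
    holds roughly a third of the tweets that could still be retrieved.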
41 | """ 42 | 43 | URLS = { 44 | "redhawks": "https://zenodo.org/record/2563864/files/DATASET_R1.xlsx", 45 | "muslim_waitress": "https://zenodo.org/record/2563864/files/DATASET_R2.xlsx", 46 | "zuckerberg_yatch": "https://zenodo.org/record/2563864/files/DATASET_R3.xlsx", 47 | "denzel_washington": "https://zenodo.org/record/2563864/files/DATASET_R4.xlsx", 48 | "veggietales": "https://zenodo.org/record/2563864/files/DATASET_R7.xlsx", 49 | "michael_jordan": "https://zenodo.org/record/2563864/files/DATASET_R8.xlsx", 50 | } 51 | 52 | data = {} 53 | 54 | for rumor, url in URLS.items(): 55 | 56 | print(rumor) 57 | 58 | df = pd.read_excel(url) 59 | df = df.sample(300) 60 | df = df.apply(get_tweet_info, axis=1) 61 | df = df.dropna(axis=0) 62 | df["stage"] = pd.qcut(df["dt"], 3, labels=STAGES) 63 | 64 | for stage in STAGES: 65 | data[f"{rumor}_{stage}"] = ( 66 | df[df.stage == stage]["text"].apply(encode_ascii).tolist() 67 | ) 68 | 69 | return data 70 | -------------------------------------------------------------------------------- /scripts/test_discriminative.py: -------------------------------------------------------------------------------- 1 | """ 2 | Baselines for problem difficulty. 3 | """ 4 | 5 | import random 6 | from typing import List 7 | 8 | import numpy as np 9 | from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer 10 | from sklearn.feature_selection import SelectKBest, chi2 11 | from sklearn.metrics import roc_auc_score 12 | from sklearn.naive_bayes import MultinomialNB 13 | from sklearn.pipeline import Pipeline 14 | 15 | SAMPLE_SIZE = 1000 16 | 17 | 18 | def discriminated_w_simple_feature( 19 | positive_samples: List[str], 20 | negative_samples: List[str], 21 | sample_size: int = SAMPLE_SIZE, 22 | k: int = 20, 23 | ) -> float: 24 | """ 25 | Assess basic discriminatory power with top k words, 26 | length, number of words, capital letters, and numbers. 27 | Returns max(auc, 1-auc) for the basic single feature. 
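    A sketch of the intended behavior on toy data: two degenerate corpora
    that differ in vocabulary and length should be near-perfectly
    separable by at least one single feature.

        >>> pos = ["the cat sat"] * 50
        >>> neg = ["a dog ran far away"] * 50
        >>> discriminated_w_simple_feature(pos, neg)
        1.0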
28 |     """
29 | 
30 |     # sample down if necessary
31 |     random.seed(2022)
32 |     if len(positive_samples) > sample_size:
33 |         positive_samples = random.sample(
34 |             positive_samples,
35 |             sample_size,
36 |         )
37 |     if len(negative_samples) > sample_size:
38 |         negative_samples = random.sample(negative_samples, sample_size)
39 | 
40 |     # ground truth
41 |     labels = np.append(np.ones(len(positive_samples)), np.zeros(len(negative_samples)))
42 | 
43 |     def eval_power(preds: List[float]):
44 |         """Evaluates predictions against ground truth"""
45 |         roc_auc = roc_auc_score(labels, preds)
46 |         return max(roc_auc, 1 - roc_auc)
47 | 
48 |     # vectorize
49 |     vectorizer = CountVectorizer()
50 |     all_samples = positive_samples + negative_samples
51 |     vectorizer.fit(all_samples)
52 |     pos_counts = vectorizer.transform(positive_samples)
53 |     neg_counts = vectorizer.transform(negative_samples)
54 | 
55 |     # count freqs
56 |     pos_freq = pos_counts.sum(axis=0)
57 |     neg_freq = neg_counts.sum(axis=0)
58 |     net_freq = pos_freq - neg_freq
59 |     arr = abs(np.array(net_freq.flatten())[0])
60 |     args = np.argsort(arr)[-k:]
61 | 
62 |     def test_word_power(arg: int):
63 |         """Tests a given word index's discriminatory power"""
64 |         pos_pred = pos_counts[:, arg].toarray().flatten() > 0
65 |         neg_pred = neg_counts[:, arg].toarray().flatten() > 0
66 |         preds = np.append(pos_pred, neg_pred)
67 |         return eval_power(preds)
68 | 
69 |     # engineer other features
70 |     max_word_auc = max(map(test_word_power, args))
71 |     pos_words = np.asarray(pos_counts.sum(axis=1)).flatten()
72 |     neg_words = np.asarray(neg_counts.sum(axis=1)).flatten()
73 |     total_words_auc = eval_power(np.append(pos_words, neg_words))
74 |     len_auc = eval_power(list(map(len, all_samples)))
75 |     number_idxs = [idx for w, idx in vectorizer.vocabulary_.items() if w.isnumeric()]
76 |     pos_nums = np.asarray(pos_counts[:, number_idxs].sum(axis=1)).flatten()
77 |     neg_nums = np.asarray(neg_counts[:, number_idxs].sum(axis=1)).flatten()
78 |     num_auc = eval_power(np.append(pos_nums, neg_nums))
79 |     capital_auc = eval_power(list(map(lambda s: s[:1].isupper(), all_samples)))  # s[:1] avoids IndexError on empty strings
80 | 
81 |     # return max discrim
82 |     return max(max_word_auc, total_words_auc, len_auc, capital_auc, num_auc)
83 | 
84 | 
85 | def discriminated_w_mnb(
86 |     positive_samples: List[str],
87 |     negative_samples: List[str],
88 |     sample_size: int = SAMPLE_SIZE,
89 |     k: int = 20,
90 | ) -> float:
91 |     """
92 |     Assess discriminatory power with a multinomial naive Bayes
93 |     classifier over the top k bag-of-words features (chi-squared).
94 |     Returns the ROC AUC of the fitted pipeline on the given samples.
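    Note that the AUC is computed on the same samples the pipeline was
    fit on, so it is an optimistic (training) estimate of separability
    rather than a held-out score.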
95 |     """
96 | 
97 |     # build pipeline
98 |     text_clf = Pipeline(
99 |         [
100 |             ("vect", CountVectorizer()),
101 |             ("tfidf", TfidfTransformer()),
102 |             ("sel", SelectKBest(chi2, k=k)),
103 |             ("clf", MultinomialNB()),
104 |         ]
105 |     )
106 | 
107 |     # sample down
108 |     if len(positive_samples) > sample_size:
109 |         positive_samples = random.sample(positive_samples, sample_size)
110 |     if len(negative_samples) > sample_size:
111 |         negative_samples = random.sample(negative_samples, sample_size)
112 | 
113 |     # build design matrix and ground-truth labels
114 |     X = positive_samples + negative_samples
115 |     y = np.append(np.ones(len(positive_samples)), np.zeros(len(negative_samples)))
116 | 
117 |     # fit classifier
118 |     clf = text_clf.fit(X, y)
119 | 
120 |     return roc_auc_score(y, clf.predict_proba(X)[:, 1])
121 | 
122 | 
123 | if __name__ == "__main__":
124 | 
125 |     import pickle as pkl
126 | 
127 |     pairs = pkl.load(open("benchmarks/benchmark_1201.pkl", "rb"))
128 |     for pair in pairs:
129 |         positive_samples = pair["pos_samples"]
130 |         negative_samples = pair["neg_samples"]
131 |         s = discriminated_w_simple_feature(positive_samples, negative_samples)
132 |         print(s)
133 |         print(pair["pair"])
134 |         exit(0)
--------------------------------------------------------------------------------
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | General utilities.
3 | """
4 | 
5 | import gzip
6 | import json
7 | import os
8 | import shutil
9 | import tarfile
10 | from html.parser import HTMLParser
11 | from io import BytesIO, StringIO
12 | from itertools import chain
13 | from os.path import join
14 | from typing import Callable, Dict, List, Tuple, Union
15 | from zipfile import ZipFile
16 | 
17 | import gdown
18 | import nltk
19 | import pandas as pd
20 | import requests
21 | from markdown import Markdown
22 | from nltk.tokenize import sent_tokenize
23 | from parameters import *
24 | 
25 | """
26 | =============
27 | Text cleaning
28 | =============
29 | """
30 | 
31 | 
32 | def remove_empty(samples: Union[List[str], List[Tuple]]) -> Union[List[str], List[Tuple]]:
33 |     """Utility for removing empty strings"""
34 |     if not samples:
35 |         return samples
36 |     s = samples[0]
37 |     if isinstance(s, str):
38 |         return [s for s in samples if s.strip()]
39 |     return [s for s in samples if s[0].strip()]
40 | 
41 | 
42 | def split_delimiter(
43 |     snippets: List[str], delimiter: str = "\n"
44 | ) -> List[Tuple[str, float]]:
45 |     """
46 |     Splits a list of snippets on a specific delimiter and
47 |     returns snippets weighted by word count.
48 |     """
49 |     return list(chain(*[split_delimiter_(w_s, delimiter) for w_s in snippets]))
50 | 
51 | 
52 | def split_delimiter_(snippet: str, delimiter: str = "\n"):
53 |     """
54 |     Splits a snippet based on a delimiter and returns
55 |     snippets weighted by word count.
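    A small example of the weighting, where each weight is that chunk's
    fraction of the original word count:

        >>> split_delimiter_("one two\nthree")
        [('one two', 0.6666666666666666), ('three', 0.3333333333333333)]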
56 | """ 57 | new_snippets = [] 58 | 59 | total_weight = len(snippet.split()) 60 | split_snips = snippet.split(delimiter) 61 | for split_snip in split_snips: 62 | weight = len(split_snip.split()) 63 | if weight: 64 | new_snippets.append((split_snip, weight / total_weight)) 65 | 66 | return new_snippets 67 | 68 | 69 | def give_full_weight(snippets: List[str]) -> List[Tuple[str, int]]: 70 | """Assigns weight 1 to every snippet""" 71 | return [(snippet, 1) for snippet in snippets] 72 | 73 | 74 | def split_truncate( 75 | weighted_snippets: Union[List[Tuple[str, float]], List[str]], 76 | cap: int = SNIPPET_MAXLEN, 77 | ) -> List[str]: 78 | """Applies split_truncate_ to a list of snippets""" 79 | if isinstance(weighted_snippets[0], str): 80 | weighted_snippets = give_full_weight(weighted_snippets) 81 | return list(chain(*[split_truncate_(w_s, cap=cap) for w_s in weighted_snippets])) 82 | 83 | 84 | def split_truncate_( 85 | weighted_snippet: Tuple[str, float], cap: int = SNIPPET_MAXLEN 86 | ) -> List[Tuple[str, float]]: 87 | """ 88 | Splits any snippet over the truncation limit and returns 89 | new snippets weighted by word count. 90 | """ 91 | if isinstance(weighted_snippet, str): 92 | weighted_snippet = (weighted_snippet, 1) 93 | assert len(weighted_snippet) == 2, "Did not provide weighted snippets" 94 | 95 | snippet, weight = weighted_snippet 96 | assert len(snippet.split()) > 0, f"String is empty: {snippet}" 97 | new_snippets = [] 98 | new_snippet = "" 99 | word_count = 0 100 | 101 | # split snippet into sentences 102 | total_weight = len(snippet.split()) 103 | sentences = sent_tokenize(snippet) 104 | 105 | for sentence in sentences: 106 | 107 | sentence_length = len(sentence.split()) 108 | 109 | # if we need to tokenize 110 | if word_count + sentence_length > cap: 111 | new_snippet = new_snippet.strip() 112 | new_weight = word_count / total_weight * weight # new weight 113 | new_snippets.append((new_snippet, new_weight)) 114 | new_snippet, word_count = "", 0 115 | 116 | # check if sentence is too long 117 | if sentence_length > cap: 118 | continue 119 | 120 | # else add sentence 121 | new_snippet += sentence + " " 122 | word_count += sentence_length 123 | 124 | # add additional snippet 125 | new_snippet = new_snippet.strip() 126 | new_weight = word_count / total_weight * weight # new weight 127 | new_snippets.append((new_snippet, new_weight)) 128 | 129 | return new_snippets 130 | 131 | 132 | def split_truncate_word( 133 | weighted_snippets: Union[List[Tuple[str, float]], List[str]], 134 | cap: int = SNIPPET_MAXLEN, 135 | ) -> List[str]: 136 | """Applies split_truncate_ to a list of snippets""" 137 | 138 | if isinstance(weighted_snippets[0], str): 139 | weighted_snippets = give_full_weight(weighted_snippets) 140 | return list( 141 | chain(*[split_truncate_word_(w_s, cap=cap) for w_s in weighted_snippets]) 142 | ) 143 | 144 | 145 | def split_truncate_word_( 146 | weighted_snippet: Tuple[str, float], cap: int = SNIPPET_MAXLEN 147 | ): 148 | """ 149 | Splits any snippet over the truncation limit at the 150 | level of the word and returns new snippets weighted by word count. 
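    For example, with cap=3 a seven-word snippet of weight 1 is chopped
    into chunks weighted 3/7, 3/7, and 1/7:

        >>> list(split_truncate_word_(("a b c d e f g", 1), cap=3))
        [('a b c', 0.42857142857142855), ('d e f', 0.42857142857142855), ('g', 0.14285714285714285)]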
151 | """ 152 | if isinstance(weighted_snippet, str): 153 | weighted_snippet = (weighted_snippet, 1) 154 | assert len(weighted_snippet) == 2, "Did not provide weighted snippets" 155 | 156 | snippet, weight = weighted_snippet 157 | assert len(snippet.split()) > 0, f"String is empty: {snippet}" 158 | 159 | words = snippet.split() 160 | total_words = len(words) 161 | 162 | # split snippet into sentences 163 | i = 0 164 | while (i + 1) * cap < total_words: 165 | new_snippet = " ".join(words[i * cap : (i + 1) * cap]) 166 | yield (new_snippet, cap / total_words * weight) 167 | i += 1 168 | 169 | tail = total_words - i * cap 170 | if tail: 171 | new_snippet = " ".join(words[i * cap :]) 172 | yield (new_snippet, tail / total_words * weight) 173 | 174 | 175 | def split_df(df: pd.DataFrame, text_col: str, splitter: Callable = split_truncate_): 176 | """ 177 | Accepts a DataFrame and a column of text, applies truncation and 178 | splits rows into multiple as necessary. 179 | """ 180 | assert text_col in df.columns, f"Columns {text_col} not in DataFrame" 181 | 182 | df[text_col] = df[text_col].apply(splitter) 183 | df = df.explode(text_col) 184 | 185 | df = df[df[text_col].str.len() > 0] 186 | 187 | return df 188 | 189 | 190 | def shorten_snippet(snippet: str, cap: int = 256) -> str: 191 | """ 192 | Shortens a text subject to some limit. 193 | """ 194 | 195 | sentences = sent_tokenize(snippet) 196 | new_snippet = "" 197 | num_words = 0 198 | 199 | for sentence in sentences: 200 | num_words = len(sentence.split()) 201 | if num_words > cap: 202 | return new_snippet.strip() 203 | new_snippet += sentence + " " 204 | 205 | return new_snippet.strip() 206 | 207 | 208 | def sentence_tokenize(snippets: List[str]) -> List[str]: 209 | """ 210 | Uses NLTK to sentence toknize a list of snippets. 211 | """ 212 | 213 | all_sentences = [] 214 | for snippet in snippets: 215 | sentences = nltk.tokenize.sent_tokenize(snippet) 216 | all_sentences += sentences 217 | 218 | return all_sentences 219 | 220 | 221 | class MLStripper(HTMLParser): 222 | """ 223 | Class to strip HTML from a string. 224 | 225 | Borrowed from https://stackoverflow.com/questions/753052/strip-html-from-strings-in-python. 
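    Used via the strip_tags helper below:

        >>> strip_tags("<p>Hello <b>world</b></p>")
        'Hello world'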
226 | """ 227 | 228 | def __init__(self): 229 | super().__init__() 230 | self.reset() 231 | self.strict = False 232 | self.convert_charrefs = True 233 | self.text = StringIO() 234 | 235 | def handle_data(self, d): 236 | self.text.write(d) 237 | 238 | def get_data(self): 239 | return self.text.getvalue() 240 | 241 | 242 | def strip_tags(html: str) -> str: 243 | """Removes HTML from a string.""" 244 | s = MLStripper() 245 | s.feed(html) 246 | return s.get_data() 247 | 248 | 249 | def encode_ascii(text: str) -> str: 250 | """Encode a text into ASCII.""" 251 | return text.encode("ascii", "ignore").decode("utf-8") 252 | 253 | 254 | def unmark_element(element, stream=None): 255 | """Outputs plain, unformatted markdown.""" 256 | if stream is None: 257 | stream = StringIO() 258 | if element.text: 259 | stream.write(element.text) 260 | for sub in element: 261 | unmark_element(sub, stream) 262 | if element.tail: 263 | stream.write(element.tail) 264 | return stream.getvalue() 265 | 266 | 267 | Markdown.output_formats["plain"] = unmark_element 268 | __md = Markdown(output_format="plain") 269 | __md.stripTopLevelTags = False 270 | 271 | 272 | def unmark(text: str) -> str: 273 | """Removes markdown formatting from a Markdown object.""" 274 | return __md.convert(text) 275 | 276 | 277 | """ 278 | ============= 279 | File download 280 | ============= 281 | """ 282 | 283 | 284 | def download_zip(url: str, directory: str): 285 | """Downloads and extracts contents of a zip folder.""" 286 | 287 | req = requests.get(url) 288 | zip = ZipFile(BytesIO(req.content)) 289 | zip.extractall(directory) 290 | 291 | 292 | def extract_zip(path: str, directory: str): 293 | """Extracts all files from a zip folder.""" 294 | with ZipFile(path, "r") as zip: 295 | zip.extractall(directory) 296 | 297 | 298 | def download_tar(url: str, directory: str): 299 | """Downloads and extracts contents of a tar file.""" 300 | 301 | response = requests.get(url, stream=True) 302 | file = tarfile.open(fileobj=response.raw, mode="r|gz") 303 | file.extractall(path=directory) 304 | 305 | 306 | def download_file(url: str, directory: str, filename: str): 307 | """Downloads and names file.""" 308 | req = requests.get(url) 309 | os.makedirs(directory, exist_ok=True) 310 | with open(join(directory, filename), "wb") as f: 311 | f.write(req.content) 312 | 313 | 314 | def download_drive_file(id: str, directory: str, filename: str): 315 | """Downloads files from Google Drive.""" 316 | 317 | url = f"https://drive.google.com/uc?id={id}" 318 | os.makedirs(directory, exist_ok=True) 319 | gdown.download(url, join(directory, filename)) 320 | 321 | 322 | def download_drive_zip(id: str, directory: str) -> None: 323 | """Downloads files from Google Drive.""" 324 | 325 | url = f"https://drive.google.com/uc?id={id}" 326 | os.makedirs(directory, exist_ok=True) 327 | gdown.download(url, join(directory, "drive.zip")) 328 | with ZipFile(join(directory, "drive.zip"), "r") as zip_ref: 329 | zip_ref.extractall(directory) 330 | 331 | 332 | def download_gz(url, directory, filename) -> None: 333 | """Downloads and opens gzip file.""" 334 | 335 | download_file(url, directory, filename + ".gz") 336 | 337 | with gzip.open(join(directory, filename + ".gz"), "rb") as f_in: 338 | with open(join(directory, filename), "wb") as f_out: 339 | shutil.copyfileobj(f_in, f_out) 340 | 341 | 342 | """ 343 | ========= 344 | File save 345 | ========= 346 | """ 347 | 348 | 349 | def save_dataset(df: pd.DataFrame, name: str): 350 | """Saves dataframe to datasets folder.""" 351 | 352 | 
os.makedirs(DATASET_FOLDER, exist_ok=True) 353 | df.to_pickle(join(DATASET_FOLDER, f"{name}.pkl")) 354 | 355 | 356 | def save_output_json(data: Dict, name: str): 357 | """Saves output data to output folder.""" 358 | 359 | output = { 360 | "name": name, 361 | "data": data, 362 | } 363 | 364 | output_file = f"{OUTPUT_FOLDER}/{name}.json" 365 | with open(output_file, "w") as outfile: 366 | json.dump(output, outfile) 367 | 368 | 369 | def save_unlabeled_json(sentences: List[str], name: str): 370 | """Saves unlabeled data to output folder.""" 371 | 372 | os.makedirs(UNLABELED_FOLDER, exist_ok=True) 373 | output_file = f"{UNLABELED_FOLDER}/{name}.json" 374 | with open(output_file, "w") as outfile: 375 | json.dump(sentences, outfile) 376 | 377 | 378 | def delete_downloads(): 379 | """Clears the downloads folder.""" 380 | shutil.rmtree(DOWNLOAD_FOLDER) 381 | --------------------------------------------------------------------------------