├── .dockerignore ├── .gitignore ├── .skiff ├── cloudbuild-deploy.yaml ├── norms.skiff.json ├── scoracle.skiff.json └── webapp.jsonnet ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── bin └── scruples ├── docs ├── annotation-guidelines.md ├── assets │ └── demo-screenshots │ │ ├── norms-about.png │ │ ├── norms-results.png │ │ ├── scoracle-about.png │ │ └── scoracle-results.png └── demos.md ├── hit-definitions └── resource-labeling │ ├── data.jsonl │ └── definition │ ├── NOTES │ ├── hitproperties.json │ ├── hittypeproperties.json │ └── question.xml.j2 ├── readme.md ├── requirements.txt ├── setup.cfg ├── setup.py ├── src └── scruples │ ├── __init__.py │ ├── analysis │ ├── __init__.py │ └── traits.py │ ├── baselines │ ├── __init__.py │ ├── bert.py │ ├── corpus.py │ ├── labels.py │ ├── linear.py │ ├── loss.py │ ├── metrics.py │ ├── naivebayes.py │ ├── resource.py │ ├── roberta.py │ ├── style.py │ ├── train.py │ ├── trees.py │ └── utils.py │ ├── data │ ├── __init__.py │ ├── action.py │ ├── comment.py │ ├── label_scores.py │ ├── labels.py │ ├── post.py │ ├── post_types.py │ └── utils.py │ ├── dataset │ ├── __init__.py │ ├── readers.py │ └── transforms.py │ ├── demos │ ├── __init__.py │ ├── norms │ │ ├── __init__.py │ │ ├── app.py │ │ ├── static │ │ │ ├── favicon.ico │ │ │ └── logo-ai2-white-withText-micro.svg │ │ ├── templates │ │ │ └── index.html │ │ └── utils.py │ └── scoracle │ │ ├── __init__.py │ │ ├── app.py │ │ ├── static │ │ ├── favicon.ico │ │ └── logo-ai2-white-withText-micro.svg │ │ └── templates │ │ ├── dev.scruples-anecdotes_label-counts.json │ │ ├── dev.scruples-dilemmas_label-counts.json │ │ ├── dev.snli_label-counts.json │ │ └── index.html │ ├── extraction │ ├── __init__.py │ ├── base.py │ ├── filters.py │ ├── normalizers.py │ └── transformers.py │ ├── scripts │ ├── __init__.py │ ├── analyze │ │ ├── __init__.py │ │ ├── corpus │ │ │ ├── __init__.py │ │ │ ├── extractions.py │ │ │ ├── human_performance.py │ │ │ ├── oracle_performance.py │ │ │ ├── predictions.py 
│ │ │ └── statistics.py │ │ ├── oracle_estimator.py │ │ └── resource │ │ │ ├── __init__.py │ │ │ ├── human_performance.py │ │ │ ├── latent_traits.py │ │ │ ├── oracle_performance.py │ │ │ ├── predictions.py │ │ │ ├── topics.py │ │ │ └── verbs.py │ ├── demo │ │ ├── __init__.py │ │ ├── norms.py │ │ └── scoracle.py │ ├── evaluate │ │ ├── __init__.py │ │ ├── corpus │ │ │ ├── __init__.py │ │ │ ├── predict_lm.py │ │ │ ├── run_shallow.py │ │ │ └── tune_lm.py │ │ └── resource │ │ │ ├── __init__.py │ │ │ ├── predict_lm.py │ │ │ ├── run_shallow.py │ │ │ └── tune_lm.py │ └── make │ │ ├── __init__.py │ │ ├── corpus.py │ │ └── resource │ │ ├── __init__.py │ │ ├── dataset.py │ │ └── proposals.py │ ├── settings.py │ ├── utils.py │ └── vendor │ ├── __init__.py │ └── skopt.py └── tests ├── __init__.py ├── conftest.py ├── fixtures ├── corpus-easy │ ├── dev.scruples-corpus.jsonl │ ├── test.scruples-corpus.jsonl │ └── train.scruples-corpus.jsonl ├── gold-extractions │ └── gold-title-description-extractions.jsonl ├── raw-data │ ├── test-comment.jsonl │ ├── test-post-comments.jsonl │ └── test-post.jsonl └── resource-easy │ ├── dev.scruples-actions.jsonl │ ├── test.scruples-actions.jsonl │ └── train.scruples-actions.jsonl ├── scruples ├── __init__.py ├── analysis │ ├── __init__.py │ └── test_traits.py ├── baselines │ ├── __init__.py │ ├── test_labels.py │ ├── test_linear.py │ ├── test_loss.py │ ├── test_metrics.py │ ├── test_naivebayes.py │ ├── test_style.py │ ├── test_utils.py │ └── utils.py ├── data │ ├── __init__.py │ ├── test_action.py │ ├── test_comment.py │ ├── test_label_scores.py │ ├── test_labels.py │ ├── test_post.py │ ├── test_post_types.py │ └── test_utils.py ├── dataset │ ├── __init__.py │ ├── test_readers.py │ └── test_transforms.py ├── demos │ ├── __init__.py │ ├── norms │ │ ├── __init__.py │ │ ├── test_app.py │ │ └── test_utils.py │ └── scoracle │ │ ├── __init__.py │ │ └── test_app.py ├── extraction │ ├── __init__.py │ ├── test_base.py │ ├── test_filters.py │ ├── 
test_normalizers.py │ └── test_transformers.py ├── test_settings.py └── test_utils.py ├── settings.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 89 | # install all needed dependencies. 
90 | #Pipfile.lock 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ -------------------------------------------------------------------------------- /.skiff/cloudbuild-deploy.yaml: -------------------------------------------------------------------------------- 1 | # This file tells Google Cloud Build how to deploy the application. 2 | # It can be attached to a variety of triggers, the default being whenever 3 | # someone merges changes to the `master` branch. 4 | timeout: '3600s' 5 | steps: 6 | # Pull the latest Docker image, so the build is faster. 
7 | - id: 'image.pull' 8 | waitFor: ['-'] 9 | name: 'gcr.io/cloud-builders/docker' 10 | entrypoint: '/bin/bash' 11 | args: [ 12 | '-c', 13 | 'docker pull gcr.io/$PROJECT_ID/$REPO_NAME:latest || exit 0' 14 | ] 15 | # Build the Docker image 16 | - id: 'image.build' 17 | waitFor: ['image.pull'] 18 | name: 'gcr.io/cloud-builders/docker' 19 | args: [ 20 | 'build', 21 | '--tag', 'gcr.io/$PROJECT_ID/$REPO_NAME:latest', 22 | '--tag', 'gcr.io/$PROJECT_ID/$REPO_NAME:$COMMIT_SHA', 23 | '--cache-from', 'gcr.io/$PROJECT_ID/$REPO_NAME:latest', 24 | '.', 25 | ] 26 | # Push the Docker image to be deployed 27 | - id: 'image.push' 28 | waitFor: ['image.build'] 29 | name: 'gcr.io/cloud-builders/docker' 30 | args: [ 31 | 'push', 32 | 'gcr.io/$PROJECT_ID/$REPO_NAME:$COMMIT_SHA', 33 | ] 34 | # Generate the Kubernetes configuration for the norms demo 35 | - id: 'config.norms' 36 | waitFor: ['-'] 37 | name: 'gcr.io/ai2-reviz/jsonnet' 38 | args: [ 39 | 'eval', 40 | '--yaml-stream', 41 | '--output-file', './norms.webapp.yaml', 42 | '--tla-str', 'app=norms', 43 | '--tla-str', 'env=$_ENV', 44 | '--tla-str', 'image=gcr.io/$PROJECT_ID/$REPO_NAME:$COMMIT_SHA', 45 | '--tla-str', 'sha=$COMMIT_SHA', 46 | '--tla-str', 'cause=Automated Skiff Deploy SHA:$COMMIT_SHA BUILD:$BUILD_ID', 47 | '--tla-str', 'branch=$BRANCH_NAME', 48 | '--tla-str', 'repo=$REPO_NAME', 49 | '--tla-str', 'buildId=$BUILD_ID', 50 | './webapp.jsonnet' 51 | ] 52 | dir: '.skiff' 53 | # Generate the Kubernetes configuration for the scoracle demo 54 | - id: 'config.scoracle' 55 | waitFor: ['-'] 56 | name: 'gcr.io/ai2-reviz/jsonnet' 57 | args: [ 58 | 'eval', 59 | '--yaml-stream', 60 | '--output-file', './scoracle.webapp.yaml', 61 | '--tla-str', 'app=scoracle', 62 | '--tla-str', 'env=$_ENV', 63 | '--tla-str', 'image=gcr.io/$PROJECT_ID/$REPO_NAME:$COMMIT_SHA', 64 | '--tla-str', 'sha=$COMMIT_SHA', 65 | '--tla-str', 'cause=Automated Skiff Deploy SHA:$COMMIT_SHA BUILD:$BUILD_ID', 66 | '--tla-str', 'branch=$BRANCH_NAME', 67 | '--tla-str', 
'repo=$REPO_NAME', 68 | '--tla-str', 'buildId=$BUILD_ID', 69 | './webapp.jsonnet' 70 | ] 71 | dir: '.skiff' 72 | # Deploy the norms demo to Kubernetes 73 | - id: 'deploy.norms' 74 | waitFor: ['image.push', 'config.norms'] 75 | name: 'gcr.io/ai2-reviz/rudder' 76 | args: [ 77 | 'deploy', 78 | '-f', 79 | 'norms.webapp.yaml' 80 | ] 81 | dir: '.skiff' 82 | # Deploy the scoracle demo to Kubernetes 83 | - id: 'deploy.scoracle' 84 | waitFor: ['image.push', 'config.scoracle'] 85 | name: 'gcr.io/ai2-reviz/rudder' 86 | args: [ 87 | 'deploy', 88 | '-f', 89 | 'scoracle.webapp.yaml' 90 | ] 91 | dir: '.skiff' 92 | substitutions: 93 | _ENV: staging 94 | images: [ 95 | 'gcr.io/$PROJECT_ID/$REPO_NAME:$COMMIT_SHA', 96 | 'gcr.io/$PROJECT_ID/$REPO_NAME:latest' 97 | ] 98 | artifacts: 99 | objects: 100 | location: 'gs://skiff-archive/$REPO_NAME/$_ENV/$BUILD_ID/$COMMIT_SHA' 101 | paths: ['.skiff/norms.webapp.yaml', '.skiff/scoracle.webapp.yaml'] 102 | -------------------------------------------------------------------------------- /.skiff/norms.skiff.json: -------------------------------------------------------------------------------- 1 | { 2 | "appName": "norms", 3 | "contact": "ronanlb", 4 | "team": "aristo", 5 | "args": [ "demo", "norms", "--port", "8000", "--expose" ], 6 | "resources": { 7 | "requests": { 8 | "cpu": "1", 9 | "memory": "16Gi" 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /.skiff/scoracle.skiff.json: -------------------------------------------------------------------------------- 1 | { 2 | "appName": "scoracle", 3 | "contact": "ronanlb", 4 | "team": "aristo", 5 | "args": [ "demo", "scoracle", "--port", "8000", "--expose" ], 6 | "resources": { 7 | "requests": { 8 | "cpu": "1.0", 9 | "memory": "2Gi" 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Usage: A 
containerized environment for scruples. 2 | 3 | FROM nvidia/cuda:10.1-devel-ubuntu18.04 4 | 5 | 6 | # Configure language settings. 7 | 8 | ENV LC_ALL=C.UTF-8 \ 9 | LANG=C.UTF-8 10 | 11 | 12 | # Create the working directory. 13 | 14 | ENV APP_DIR=/app 15 | 16 | WORKDIR $APP_DIR 17 | 18 | 19 | # Install the OS packages. 20 | 21 | RUN apt-get update --fix-missing \ 22 | && DEBIAN_FRONTEND=noninteractive \ 23 | apt-get install --no-install-recommends --yes \ 24 | build-essential \ 25 | curl \ 26 | git \ 27 | libbz2-dev \ 28 | libffi-dev \ 29 | liblzma-dev \ 30 | libmysqlclient-dev \ 31 | libncurses5-dev \ 32 | libreadline-dev \ 33 | libsqlite3-dev \ 34 | libssl-dev \ 35 | libxml2-dev \ 36 | libxmlsec1-dev \ 37 | llvm \ 38 | make \ 39 | tk-dev \ 40 | wget \ 41 | xz-utils \ 42 | zlib1g-dev 43 | 44 | 45 | # Install Python. 46 | 47 | # Set environment variables for pyenv. 48 | 49 | ENV PYTHON_VERSION="3.7.0" \ 50 | PYENV_ROOT="$APP_DIR/.pyenv" 51 | 52 | ENV PATH="$PYENV_ROOT/bin:$PATH" 53 | 54 | # Install pyenv version 1.2.15. 55 | 56 | RUN git clone https://github.com/pyenv/pyenv.git $PYENV_ROOT \ 57 | && cd $PYENV_ROOT \ 58 | && git checkout b5f69fe58a58f3d2d59511c4630c96aafe5ad2b0 59 | 60 | # Install the python version. 61 | 62 | RUN pyenv install $PYTHON_VERSION 63 | 64 | # Place the new python version on the path. 65 | 66 | ENV PATH="$PYENV_ROOT/versions/$PYTHON_VERSION/bin:$PATH" 67 | 68 | 69 | # Install python package dependencies. 70 | 71 | COPY requirements.txt . 72 | 73 | RUN pip install --upgrade pip \ 74 | && pip install --requirement requirements.txt 75 | 76 | 77 | # Download the spacy models. 78 | 79 | RUN python -m spacy download en 80 | 81 | 82 | # Setup the norms demo. 83 | 84 | # Download the models. 
85 | 86 | RUN mkdir --parents models/norms 87 | 88 | RUN mkdir models/norms/actions \ 89 | && wget \ 90 | --quiet \ 91 | --directory-prefix models/norms/actions \ 92 | https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/actions/config.json \ 93 | https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/actions/pytorch_model.bin 94 | 95 | RUN mkdir models/norms/corpus \ 96 | && wget \ 97 | --quiet \ 98 | --directory-prefix models/norms/corpus \ 99 | https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/corpus/config.json \ 100 | https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/corpus/pytorch_model.bin 101 | 102 | # Set the environment variables. 103 | 104 | ENV SCRUPLES_NORMS_ACTIONS_BASELINE="roberta" \ 105 | SCRUPLES_NORMS_ACTIONS_MODEL="/app/models/norms/actions" \ 106 | SCRUPLES_NORMS_CORPUS_BASELINE="roberta" \ 107 | SCRUPLES_NORMS_CORPUS_MODEL="/app/models/norms/corpus" \ 108 | SCRUPLES_NORMS_PREDICT_BATCH_SIZE="2" \ 109 | SCRUPLES_NORMS_GPU_IDS="" 110 | 111 | 112 | # Copy over the source code 113 | 114 | COPY . . 115 | 116 | 117 | # Install the scruples package. 118 | 119 | RUN pip install . 120 | 121 | 122 | # Define the entrypoint and default command. 
123 | 124 | ENTRYPOINT [ "scruples" ] 125 | 126 | CMD [ "demo", "norms", "--port", "8000", "--expose" ] 127 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include readme.md 2 | include docs/*.md 3 | include src/scruples/demos/norms/templates/* 4 | include src/scruples/demos/norms/static/* 5 | include src/scruples/demos/scoracle/templates/* 6 | include src/scruples/demos/scoracle/static/* 7 | include tests/fixtures/* -------------------------------------------------------------------------------- /bin/scruples: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | """The command line interface for scruples.""" 4 | 5 | from scruples import scripts 6 | 7 | 8 | if __name__ == '__main__': 9 | scripts.scruples() 10 | -------------------------------------------------------------------------------- /docs/annotation-guidelines.md: -------------------------------------------------------------------------------- 1 | Annotation Guidelines 2 | ===================== 3 | Guidelines for annotating the ground truth for the data. 4 | 5 | These guidelines define how to annotate the ground truth for the reddit 6 | data in order to produce extraction quality evaluations. There are two 7 | kinds of annotations used in evaluating the extractions: comment and 8 | post annotations. 9 | 10 | 11 | Comment Annotations 12 | ------------------- 13 | Comment annotations should be written in a JSON Lines file where each 14 | object has the following keys: 15 | 16 |
17 |   <dt>id</dt>
18 |   <dd>
19 |     the ID attribute of the corresponding comment
20 |   </dd>
21 |   <dt>label</dt>
22 |   <dd>
23 |     a gold annotation for the label (one of "AUTHOR",
24 |     "OTHER", "EVERYBODY",
25 |     "NOBODY", or "INFO") expressed by the
26 |     comment, or null if no label is expressed
27 |   </dd>
28 |   <dt>implied</dt>
29 |   <dd>
30 |     true if the label is implied by the view of the author
31 |     and false if the label is somehow explicitly stated
32 |   </dd>
33 |   <dt>spam</dt>
34 |   <dd>
35 |     true if the comment is spam, false
36 |     otherwise
37 |   </dd>
38 | </dl>
39 | 
40 | The possible labels are:
41 | 
42 | <dl>
43 |   <dt>AUTHOR</dt>
44 |   <dd>
45 |     The author of the anecdote is in the wrong.
46 |   </dd>
47 |   <dt>OTHER</dt>
48 |   <dd>
49 |     The other person in the anecdote is in the wrong.
50 |   </dd>
51 |   <dt>EVERYBODY</dt>
52 |   <dd>
53 |     Everyone in the anecdote is in the wrong.
54 |   </dd>
55 |   <dt>NOBODY</dt>
56 |   <dd>
57 |     No one in the anecdote is in the wrong.
58 |   </dd>
59 |   <dt>INFO</dt>
60 |   <dd>
61 |     More information is required to make a judgment.
62 |   </dd>
63 | </dl>
64 | 65 | If the comment explicitly expresses a label either by its initialism or 66 | some phrase corresponding to the initialism, then use that label for the 67 | comment. Similarly, mark the comment with `implied` as `false` and 68 | `spam` as `false`. 69 | 70 | If the comment expresses multiple labels with no clear winner or is 71 | otherwise ambiguous, mark `label` as `null`, `implied` as `null`, and 72 | `spam` as `true`. 73 | 74 | If the comment expresses no labels explicitly but still has a viewpoint 75 | that clearly expresses one of the labels, then use that label for the 76 | comment. Mark `implied` as `true` and `spam` as `false`. 77 | 78 | Finally, if the comment expresses no label (i.e., none of `AUTHOR`, 79 | `OTHER`, `NOBODY`, `EVERYBODY`, or `INFO`), then mark `label` as `null`, 80 | `implied` as `null`, and `spam` as `true`. 81 | 82 | 83 | Post Annotations 84 | ---------------- 85 | Post annotations should be written in a JSON Lines file where each 86 | object has the following keys: 87 | 88 |
89 |   <dt>id</dt>
90 |   <dd>
91 |     the ID attribute of the corresponding post
92 |   </dd>
93 |   <dt>post_type</dt>
94 |   <dd>
95 |     a gold annotation for the post's type
96 |   </dd>
97 |   <dt>implied</dt>
98 |   <dd>
99 |     true if the post type is not explicitly stated in the
100 |     post title.
101 |   </dd>
102 |   <dt>spam</dt>
103 |   <dd>
104 |     true if the post is spam, false otherwise
105 |   </dd>
106 | </dl>
107 | 108 | Possible post types are: 109 | 110 |
111 |   <dt>HISTORICAL</dt>
112 |   <dd>The author is asking "am I the a**hole?"</dd>
113 |   <dt>HYPOTHETICAL</dt>
114 |   <dd>The author is asking "would I be the a**hole?"</dd>
115 |   <dt>META</dt>
116 |   <dd>The post is about the subreddit itself.</dd>
117 | </dl>
118 | 119 | If the post type is explicitly stated in the post title, then mark 120 | `post_type` as the stated post type, mark `implied` as `false`, and 121 | `spam` as `false`, unless the post type is `META` in which case mark 122 | spam as `true`. Additionally, if the post type is explicitly stated but 123 | clearly wrong (such as using HISTORICAL for a HYPOTHETICAL post), then 124 | use the true post type rather than the stated one. 125 | 126 | If the post type is not explicitly stated in the post title, but 127 | otherwise clear from the post, mark the appropriate post type, mark 128 | `implied` as `true` and `spam` as `false`. 129 | 130 | If the post cannot be categorized into one of the types above, mark the 131 | `post_type` as `null`, `implied` as `null`, and `spam` as `true`. 132 | 133 | If the post is something that should not be present in the dataset (for 134 | example a deleted post), then mark `spam` as `true`. 135 | -------------------------------------------------------------------------------- /docs/assets/demo-screenshots/norms-about.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/scruples/c79697e97d24f5b43b5ff030a9b16011a32e4a5f/docs/assets/demo-screenshots/norms-about.png -------------------------------------------------------------------------------- /docs/assets/demo-screenshots/norms-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/scruples/c79697e97d24f5b43b5ff030a9b16011a32e4a5f/docs/assets/demo-screenshots/norms-results.png -------------------------------------------------------------------------------- /docs/assets/demo-screenshots/scoracle-about.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/scruples/c79697e97d24f5b43b5ff030a9b16011a32e4a5f/docs/assets/demo-screenshots/scoracle-about.png 
-------------------------------------------------------------------------------- /docs/assets/demo-screenshots/scoracle-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/scruples/c79697e97d24f5b43b5ff030a9b16011a32e4a5f/docs/assets/demo-screenshots/scoracle-results.png -------------------------------------------------------------------------------- /docs/demos.md: -------------------------------------------------------------------------------- 1 | Demos 2 | ===== 3 | Scruples ships with two demos, [scoracle](#scoracle) and [norms](#norms). You 4 | can visit these demos live on the web, or you can run them yourself. 5 | 6 | 7 | Setup 8 | ----- 9 | To run either of the demos, first you need to follow the 10 | [Setup](../readme.md#setup) and [Quickstart](../readme.md#quickstart) 11 | documentation in the [readme](../readme.md). 12 | 13 | 14 | Scoracle 15 | -------- 16 | To run scoracle, use the `scruples demo scoracle` command. The command is 17 | self-documenting: 18 | 19 | $ scruples demo scoracle --help 20 | Usage: scruples demo scoracle [OPTIONS] 21 | 22 | Serve the scoracle demo. 23 | 24 | Options: 25 | --expose Whether to expose the server to the internet, or run on 26 | localhost. 27 | --port INTEGER The port on which to serve the demo. Defaults to 5000. 28 | --help Show this message and exit. 
29 | 30 | So, to run scoracle on localhost at port 5000, execute: 31 | 32 | scruples demo scoracle 33 | 34 | When you visit the site, you should see something like this: 35 | 36 | ![Scoracle Demo About Screenshot](./assets/demo-screenshots/scoracle-about.png?raw=true "Scoracle Demo About") 37 | 38 | After submitting a dataset and requesting the BEST performance for some 39 | metrics, you'll see something like this: 40 | 41 | ![Scoracle Demo Results Screenshot](./assets/demo-screenshots/scoracle-results.png?raw=true "Scoracle Demo Results") 42 | 43 | 44 | Norms 45 | ----- 46 | Before running the `norms` demo, you'll have to download the config files and 47 | weights: 48 | 49 | - **Scruples Anecdotes demo model**: [config][anecdotes-demo-config] 50 | [weights][anecdotes-demo-weights] 51 | - **Scruples Dilemmas demo model**: [config][dilemmas-demo-config] 52 | [weights][dilemmas-demo-weights] 53 | 54 | Once you've obtained the weights, set the following environment variables: 55 | 56 | SCRUPLES_NORMS_ACTIONS_BASELINE=roberta 57 | SCRUPLES_NORMS_ACTIONS_MODEL=$DILEMMAS_MODEL_PATH 58 | SCRUPLES_NORMS_CORPUS_BASELINE=roberta 59 | SCRUPLES_NORMS_CORPUS_MODEL=$ANECDOTES_MODEL_PATH 60 | SCRUPLES_NORMS_PREDICT_BATCH_SIZE=$BATCH_SIZE 61 | SCRUPLES_NORMS_GPU_IDS=$GPU_IDS 62 | 63 | For each of the models, the path to it should be a directory containing both 64 | the config and the weights file. `$GPU_IDS` should be a comma separated string 65 | of integers corresponding to which GPUs to use. You can set 66 | `SCRUPLES_NORMS_GPU_IDS` to the empty string to use the CPU instead. Start out 67 | with a value of 1 or 2 for `$BATCH_SIZE`, increasing it if you need more 68 | performance and have hardware that can support larger batches. 69 | 70 | To run norms, use the `scruples demo norms` command. The command is 71 | self-documenting: 72 | 73 | $ scruples demo norms --help 74 | Usage: scruples demo norms [OPTIONS] 75 | 76 | Serve the norms demo. 
77 | 78 | In order to run this server, you must set the following environment 79 | variables: 80 | 81 | SCRUPLES_NORMS_ACTIONS_BASELINE : The baseline to use for the resource 82 | (actions). 83 | SCRUPLES_NORMS_ACTIONS_MODEL : The path to the saved pretrained 84 | model to use for predicting the actions. 85 | SCRUPLES_NORMS_CORPUS_BASELINE : The baseline to use for the corpus. 86 | SCRUPLES_NORMS_CORPUS_MODEL : The path to the saved pretrained 87 | model to use for predicting the corpus. 88 | SCRUPLES_NORMS_PREDICT_BATCH_SIZE : The batch size to use for 89 | prediction. 90 | SCRUPLES_NORMS_GPU_IDS : A comma separated list of GPU IDs to 91 | use. If none are provided, then the CPU will be used instead. 92 | 93 | Options: 94 | --expose Whether to expose the server to the internet, or run on 95 | localhost. 96 | --port INTEGER The port on which to serve the demo. Defaults to 5050. 97 | --help Show this message and exit. 98 | 99 | So, to run norms on localhost at port 5050, execute: 100 | 101 | scruples demo norms 102 | 103 | When you visit the site, you should see something like this: 104 | 105 | ![Norms Demo About Screenshot](./assets/demo-screenshots/norms-about.png?raw=true "Norms Demo About") 106 | 107 | And after submitting a dilemma to the model, you should see results like this: 108 | 109 | ![Norms Demo Results Screenshot](./assets/demo-screenshots/norms-results.png?raw=true "Norms Demo Results") 110 | 111 | 112 | [anecdotes-demo-config]: https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/corpus/config.json 113 | [anecdotes-demo-weights]: https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/corpus/pytorch_model.bin 114 | [dilemmas-demo-config]: https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/actions/config.json 115 | [dilemmas-demo-weights]: https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/actions/pytorch_model.bin 116 | 
-------------------------------------------------------------------------------- /hit-definitions/resource-labeling/data.jsonl: -------------------------------------------------------------------------------- 1 | {"instances":[{"id":"gVJWzXwANx0nDiJfBNBX8XCXUrmY5iiS","actions":[{"id":"W0zf9dPGrSCeFY2LRbAOknRppO24JkKM","description":"not encouraging my 6yo to have a relationship with her bio dad and 4 half siblings"},{"id":"f3qwv4n2oHAe09qBumna9K2HSu5u2jjS","description":"poking fun at my wife for getting off her diet"}],"label":0},{"id":"TxZG9YEwh575rli6L5P08zRaqMq8WrpH","actions":[{"id":"56s93ij0Biuaro7p0cgkbBeP22FUugjT","description":"using the bathroom together with my boyfriend despite our roommate asking us not to because it makes her uncomfortable"},{"id":"RMTVlCSVNNrPKl19u7D1yICikSPExskv","description":"saying \"who cares\" in a topic about a dead guy with CANCER"}],"label":0},{"id":"Y6SUccZCkoH2R4Pese1nza7HsXlrb32G","actions":[{"id":"XIvbvXQTL9cgp6Sz3TvUlg0Yzvt3r9mE","description":"leaving spoiled food for a food thief to steal"},{"id":"lrbFHh9LucZrDHhqVpoCTMQnUDHwO8bN","description":"filing a charge back against an artist who got my character wrong"}],"label":0},{"id":"wxDLCFx1do97axndP0hCeEkNtq5q2N7Z","actions":[{"id":"dCt8eqx9DORDvKd7drSFQbALuPTlLcIC","description":"saying you can't 'accidentally' sleep with someone"},{"id":"IJol2KEyplmWNVp9ChYr3wMeKZg7vELC","description":"slapping the shit out of a good friend"}],"label":0},{"id":"h2lsDLQBh1QePOgEtdIxnFqcueqasY67","actions":[{"id":"aPMp62Pl181ymGnxuHAssmUDfevDTgiD","description":"being pissed off at my metalhead co-worker"},{"id":"fkyecpgn2jjtlavGVoZAUIMq33rgXGdo","description":"pursuing my friend even after she's said no"}],"label":0},{"id":"SaxKK3kqIz1GhvcbvTUZ9i6sJdAsQcpL","actions":[{"id":"6qpRWTv5cS7gvswu9Sym8fnTW83SDVwE","description":"switching obese people's drink at McDonald's"},{"id":"ltf6CjS90dybzOfpmS1fjMi0C9bovbOG","description":"telling girl I've been seeing that each gender has its own 
advantages"}],"label":1},{"id":"kyMVRMbHCPLIwXMSenMlQxkDQ8VbejPp","actions":[{"id":"ci7JOi6a1yUZXLHGWZVQEBjEGG3ONSml","description":"rejecting a 'Suicidal' girl"},{"id":"K7pIEoSUkcKA3TGsmTZmHgaIiMN8ZWia","description":"considering my alcoholic husband's (sober) lies a separate \"choice\" from the disease itself"}],"label":0},{"id":"QXPYHAc3xZiQkO5L5VGqPyIK7lxYm5Ak","actions":[{"id":"TR35CUVusNNRR9PatcbIADCYiwwmPvfy","description":"criticizing my girlfriend's excessive Instagram browsing when I myself browse Reddit the same amount"},{"id":"tbbmaVTDLKIjvFYsBn9kotQSVouL8yto","description":"\"wanting to speak to the manager\""}],"label":1},{"id":"aw3pyzwwCI4x30EcOQQ5ARzjEhhOFxSG","actions":[{"id":"zWb9ZYmHdsN37JPKrgAUQrAQZZXqruwP","description":"leaving my boyfriend after he got belligerently drunk and pissed in the floor"},{"id":"6pQlP6mvYvBZP12g3G1EQpXQFn9msuT3","description":"yelling at some kids for not including my disabled daughter in play"}],"label":0},{"id":"AlgukffQBwubC84m3qnXC52XutCwTz8M","actions":[{"id":"PV2bVUdxZwDkePyRyAjMx1vy4XYBIy8q","description":"losing a friend who moved overseas"},{"id":"YyRNoQO6W4Oqo6NJyZ5PmAmDSYVIqKVs","description":"not wanting to continue pregnancy"}],"label":1},{"id":"Cc6kKRjf0D42XpmAvwiKqyiOEfAcBEs8","actions":[{"id":"TqkPxBhuSpR6IHq1S5Ah1VzpY4F52JLD","description":"possibly contributing to a classmate developing severe​ social anxiety and dropping out of school"},{"id":"IZ8duK1YfDHiClkK6TXrJvOsmLewuGbO","description":"reminding my gf about the meals I paid for"}],"label":1},{"id":"T7ejY85ym1RqsSTXmuyEFOSUsd4g4wIP","actions":[{"id":"94a5I6MFvMfHWDMM0H2jL8UaR0qV9iwM","description":"purposely one upping my broke brother with gifts for his kids"},{"id":"G0khTLzOFpLLfKW3wzjkdJ6fkWFykqMo","description":"snapping at a girl who has a short-ish temper"}],"label":1},{"id":"bYpGXBrdKy8UWLMECeD8JeU2FAwuN6Ui","actions":[{"id":"9LasoY9QCCBLA4zFRlqtQyOaDTFx6faJ","description":"calling the RA on a girl in the 
bathroom"},{"id":"IPvfLdTjrGbRjweSWCGeOCLAGkGuO0Nu","description":"recognizing Old Enemy as Juror"}],"label":0},{"id":"XRuFXoM3p6mlNa8Fn6ZOdXLRAxf8Tf5c","actions":[{"id":"hWZwVlcawqGxkIpHwJ40GHpVlKRg0S1g","description":"getting mad at my friend for stealing my phone and going into the bathroom with it"},{"id":"i9G2t1gtJwY5ddAzn4T4Pctv6spJb2Su","description":"batting away a lady's hand when she tried to touch my hair"}],"label":1},{"id":"zWuZ0WpFN44oReu4hQ7FN8ILSUzV1ChY","actions":[{"id":"7zeCf23MmqSrphChNgQb6Od9Hwjgj8Q4","description":"not wanting to work a day I asked for off a month ago"},{"id":"eruzElCLr6s9RwN1FrBWHqkYAq1FjEVS","description":"keying a car that was parked in my assigned spot"}],"label":0},{"id":"O2Doeu6A1oVgPjxqTdbPwH6wmdUcxnfm","actions":[{"id":"GYxsQrHIjDPn7a355UZWD0tMoXpRp2hJ","description":"becoming best friends with a friend's crush and now he's having a hard time because he feels I'm a better match for her than he is"},{"id":"SNYOv3kG67s5RR5IXnIqZ2IHq7t32eXv","description":"calling the police/cps on my SIL"}],"label":1},{"id":"Pot8x0BLD9KpbmVsUQOgSe3rKvlgfItX","actions":[{"id":"7KdzoC73NCI3t4TU0OfbfxMB5OspcFbn","description":"asking my neighbor to make sure their kids do not shoot fireworks so close to the car"},{"id":"QSxaCCE2JAUylq18OXRQb4UShKLBgsRR","description":"not going to my best friend and housemates hen party"}],"label":1},{"id":"Nx8f7BpokrNqOwZC4WbpVk2eOUOYPfvc","actions":[{"id":"i0bHuPXMTd3fmlpaTEcWWKmnM77dOpjv","description":"telling my ex-girlfriend's sister that she would never see my kids because she is pornstar"},{"id":"m0WHBUD8zXHYLVJiQsUCUW9xOrSF78wm","description":"being upset with my wife about a death in our family"}],"label":1},{"id":"gq9MOCBomD9o1KdQQhZqOISwAXKvq2SE","actions":[{"id":"QgW7pF3Soh3qo7JpWVclEuuehX9kuXnN","description":"blocking our tenant from the driveway"},{"id":"3fQx0WkNAgNshQwHvqch98AGKEcbpAio","description":"breaking up with my so who revealed they're 
trans"}],"label":1},{"id":"8gjhpGjfo1hmDiEdInpM06Ui5Bu73X9A","actions":[{"id":"dKwpB2xvbgr35SWdg7ibkQQQUhPmFqJ4","description":"telling my wife I won't try to have kids until she starts trying to take care of herself"},{"id":"9uJTZl1VhVFZKoNlfZ3ZPlNesbb4mlTn","description":"taking food off my friend's plate without asking first"}],"label":0}]} 2 | -------------------------------------------------------------------------------- /hit-definitions/resource-labeling/definition/NOTES: -------------------------------------------------------------------------------- 1 | This HIT asks people to label instances from the scruples resource. 2 | -------------------------------------------------------------------------------- /hit-definitions/resource-labeling/definition/hitproperties.json: -------------------------------------------------------------------------------- 1 | { 2 | "MaxAssignments": 10, 3 | "LifetimeInSeconds": 259200 4 | } 5 | -------------------------------------------------------------------------------- /hit-definitions/resource-labeling/definition/hittypeproperties.json: -------------------------------------------------------------------------------- 1 | { 2 | "AutoApprovalDelayInSeconds": 7200, 3 | "AssignmentDurationInSeconds": 7200, 4 | "Reward": "0.60", 5 | "Title": "What Would Be Worse? 
(WARNING: adult content / situations)", 6 | "Keywords": "artificial intelligence common sense commonsense ethics", 7 | "Description": "Pick the action that would be worse for someone to do.", 8 | "QualificationRequirements": [ 9 | { 10 | "QualificationTypeId": "00000000000000000040", 11 | "Comparator": "GreaterThanOrEqualTo", 12 | "IntegerValues": [ 13 | 5000 14 | ], 15 | "RequiredToPreview": false 16 | }, 17 | { 18 | "QualificationTypeId": "00000000000000000071", 19 | "Comparator": "In", 20 | "LocaleValues": [{"Country": "US"}, {"Country": "CA"}], 21 | "RequiredToPreview": false 22 | }, 23 | { 24 | "QualificationTypeId": "000000000000000000L0", 25 | "Comparator": "GreaterThanOrEqualTo", 26 | "IntegerValues": [ 27 | 99 28 | ], 29 | "RequiredToPreview": false 30 | }, 31 | { 32 | "QualificationTypeId": "00000000000000000060", 33 | "Comparator": "EqualTo", 34 | "IntegerValues": [1], 35 | "RequiredToPreview": true 36 | } 37 | ] 38 | } 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.2.0 2 | astunparse==1.6.3 3 | atomicwrites==1.4.1 4 | attrs==19.1.0 5 | autocommand==2.2.1 6 | autograd==1.2 7 | backports.csv==1.0.7 8 | beautifulsoup4==4.11.1 9 | blis==0.2.4 10 | boto3==1.24.82 11 | botocore==1.27.82 12 | cachetools==5.2.0 13 | certifi==2022.9.24 14 | cffi==1.15.1 15 | charset-normalizer==2.1.1 16 | cheroot==8.6.0 17 | CherryPy==18.8.0 18 | Click==7.0 19 | cryptography==38.0.1 20 | cycler==0.11.0 21 | cymem==2.0.6 22 | decorator==5.1.1 23 | dill==0.2.9 24 | feedparser==6.0.10 25 | Flask==1.1.1 26 | flatbuffers==22.9.24 27 | ftfy==5.5.1 28 | future==0.18.2 29 | gast==0.4.0 30 | gevent==21.12.0 31 | google-auth==2.12.0 32 | google-auth-oauthlib==0.4.6 33 | google-pasta==0.2.0 34 | greenlet==1.1.3 35 | grpcio==1.49.1 36 | h5py==3.7.0 37 | idna==3.4 38 | importlib-metadata==4.12.0 39 | importlib-resources==5.9.0 40 | 
inflect==6.0.0 41 | itsdangerous==1.1.0 42 | jaraco.classes==3.2.3 43 | jaraco.collections==3.5.2 44 | jaraco.context==4.1.2 45 | jaraco.functools==3.5.2 46 | jaraco.text==3.9.1 47 | Jinja2==2.11.3 48 | jmespath==1.0.1 49 | joblib==1.2.0 50 | jsonschema==2.6.0 51 | keras==2.10.0 52 | Keras-Preprocessing==1.1.2 53 | kiwisolver==1.4.4 54 | libclang==14.0.6 55 | lxml==4.9.1 56 | Markdown==3.4.1 57 | MarkupSafe==1.1.1 58 | matplotlib==3.0.3 59 | more-itertools==8.14.0 60 | murmurhash==1.0.8 61 | mysqlclient==2.1.1 62 | networkx==2.3 63 | nltk==3.6.3 64 | numpy==1.21.6 65 | oauthlib==3.2.1 66 | opt-einsum==3.3.0 67 | packaging==21.3 68 | pandas==0.24.2 69 | patsy==0.5.2 70 | Pattern==3.6 71 | pdfminer.six==20220524 72 | Pillow==9.2.0 73 | plac==0.9.6 74 | pluggy==1.0.0 75 | portend==3.1.0 76 | preshed==2.0.1 77 | protobuf==3.19.5 78 | py==1.11.0 79 | pyasn1==0.4.8 80 | pyasn1-modules==0.2.8 81 | pycparser==2.21 82 | pydantic==1.10.2 83 | pyparsing==3.0.9 84 | pytest==4.4.1 85 | python-dateutil==2.8.0 86 | python-docx==0.8.11 87 | pytz==2022.2.1 88 | regex==2018.1.10 89 | requests==2.28.1 90 | requests-oauthlib==1.3.1 91 | rsa==4.9 92 | s3transfer==0.6.0 93 | sacremoses==0.0.53 94 | scikit-learn==0.20.3 95 | scikit-optimize==0.5.2 96 | scipy==1.2.1 97 | seaborn==0.9.0 98 | sentencepiece==0.1.97 99 | sgmllib3k==1.0.0 100 | six==1.16.0 101 | soupsieve==2.3.2.post1 102 | spacy==2.1.3 103 | srsly==1.0.5 104 | statsmodels==0.11.1 105 | tempora==5.0.2 106 | tensorboard==2.10.1 107 | tensorboard-data-server==0.6.1 108 | tensorboard-plugin-wit==1.8.1 109 | tensorboardX==1.6 110 | tensorflow==2.10.0 111 | tensorflow-estimator==2.10.0 112 | tensorflow-io-gcs-filesystem==0.27.0 113 | termcolor==2.0.1 114 | thinc==7.0.8 115 | torch==1.3.1 116 | torchvision==0.4.2 117 | tqdm==4.32.2 118 | transformers==2.1.1 119 | typing_extensions==4.3.0 120 | urllib3==1.26.12 121 | wasabi==0.10.1 122 | wcwidth==0.2.5 123 | Werkzeug==1.0.1 124 | wrapt==1.14.1 125 | xgboost==0.90 126 | 
zc.lockfile==2.0 127 | zipp==3.8.1 128 | zope.event==4.5.0 129 | zope.interface==5.4.0 130 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | 5 | [tool:pytest] 6 | addopts = 7 | --strict 8 | -p no:warnings 9 | markers = 10 | slow -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """setup.py file for packaging ``scruples``.""" 2 | 3 | from setuptools import setup 4 | 5 | 6 | with open('readme.md', 'r') as readme_file: 7 | readme = readme_file.read() 8 | 9 | with open('requirements.txt', 'r') as requirements_file: 10 | requirements = requirements_file.readlines() 11 | 12 | 13 | setup( 14 | name='scruples', 15 | version='0.3.0', 16 | description='A corpus and code for understanding norms and subjectivity.', 17 | long_description=readme, 18 | url='https://github.com/allenai/scruples', 19 | author='Nicholas Lourie', 20 | author_email='nicholasl@allenai.org', 21 | keywords='scruples social norms artificial intelligence ai' 22 | ' machine learning ml', 23 | classifiers=[ 24 | 'Development Status :: 3 - Alpha', 25 | 'Programming Language :: Python :: 3.7', 26 | 'License :: OSI Approved :: Apache Software License', 27 | 'Intended Audience :: Science/Research', 28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence' 29 | ], 30 | license='Apache', 31 | packages=[ 32 | 'scruples', 33 | 'scruples.analysis', 34 | 'scruples.baselines', 35 | 'scruples.data', 36 | 'scruples.dataset', 37 | 'scruples.demos', 38 | 'scruples.demos.norms', 39 | 'scruples.demos.scoracle', 40 | 'scruples.extraction', 41 | 'scruples.scripts', 42 | 'scruples.scripts.analyze', 43 | 'scruples.scripts.analyze.corpus', 44 | 'scruples.scripts.analyze.resource', 45 | 'scruples.scripts.demo', 46 | 
'scruples.scripts.evaluate', 47 | 'scruples.scripts.evaluate.corpus', 48 | 'scruples.scripts.evaluate.resource', 49 | 'scruples.scripts.make', 50 | 'scruples.scripts.make.resource', 51 | 'scruples.vendor' 52 | ], 53 | package_dir={'': 'src'}, 54 | scripts=['bin/scruples'], 55 | install_requires=[ 56 | 'Click >= 7.0', 57 | 'Flask >= 1.1.1', 58 | 'Pattern >= 3.6', 59 | 'attrs >= 19.1.0', 60 | 'autograd >= 1.2', 61 | 'dill >= 0.2.9', 62 | 'ftfy >= 5.5.1', 63 | 'gevent >= 1.4.0', 64 | 'itsdangerous < 2.0.0', 65 | 'Jinja2 < 3.0.0', 66 | 'MarkupSafe < 2.0.0', 67 | 'matplotlib >= 3.0.3', 68 | 'networkx >= 2.3', 69 | 'numpy >= 1.16.2', 70 | 'pandas >= 0.24.2', 71 | 'regex >= 2018.1.10', 72 | 'scikit-learn >= 0.20.3, < 0.21', 73 | 'scikit-optimize >= 0.5.2', 74 | 'scipy >= 1.2.1, < 1.3', 75 | 'seaborn >= 0.9.0', 76 | 'spacy >= 2.1.3, < 2.2.0', 77 | 'tensorboard >= 1.15.0', 78 | 'tensorboardX >= 1.6', 79 | 'tensorflow >= 1.15.2', 80 | 'tensorflow-estimator >= 1.15.1', 81 | 'torch >= 1.3.1', 82 | 'torchvision >= 0.4.2', 83 | 'transformers >= 2.1.1', 84 | 'Werkzeug < 2.0.0', 85 | 'xgboost >= 0.82' 86 | ], 87 | setup_requires=[ 88 | 'pytest-runner' 89 | ], 90 | tests_require=[ 91 | 'pytest' 92 | ], 93 | include_package_data=True, 94 | python_requires='>= 3.7', 95 | zip_safe=False) 96 | -------------------------------------------------------------------------------- /src/scruples/__init__.py: -------------------------------------------------------------------------------- 1 | """Build a dataset for detecting social norm violations.""" 2 | 3 | from . import ( 4 | analysis, 5 | baselines, 6 | data, 7 | dataset, 8 | extraction, 9 | scripts, 10 | settings, 11 | utils, 12 | vendor) 13 | -------------------------------------------------------------------------------- /src/scruples/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | """Tools for conducting analyses.""" 2 | 3 | from . 
import traits 4 | -------------------------------------------------------------------------------- /src/scruples/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | """Baselines for scruples.""" 2 | 3 | from . import ( 4 | bert, 5 | labels, 6 | linear, 7 | metrics, 8 | naivebayes, 9 | roberta, 10 | style, 11 | train, 12 | trees, 13 | utils) 14 | -------------------------------------------------------------------------------- /src/scruples/baselines/bert.py: -------------------------------------------------------------------------------- 1 | """BERT baselines.""" 2 | 3 | from transformers import ( 4 | BertForMultipleChoice, 5 | BertForSequenceClassification, 6 | BertTokenizer) 7 | import skopt 8 | import torch 9 | 10 | from ..data.labels import Label 11 | from ..dataset.transforms import ( 12 | BertTransform, 13 | Compose, 14 | Map) 15 | 16 | 17 | # the BERT sequence classification baseline 18 | 19 | BERTClassifier = BertForSequenceClassification.from_pretrained 20 | """Predict fixed classes with a fine-tuned BERT model.""" 21 | 22 | 23 | BERT_CLASSIFIER_CONFIG = { 24 | 'model': { 25 | # N.B. pretrained_model_name_or_path for the model must be the 26 | # same as pretrained_bert for the transform 27 | 'pretrained_model_name_or_path': 'bert-large-uncased', 28 | 'num_labels': len(Label) 29 | }, 30 | 'transform': { 31 | # N.B. 
pretrained_bert for the transform must be the same as 32 | # pretrained_model_name_or_path for the model 33 | 'pretrained_bert': 'bert-large-uncased', 34 | 'max_sequence_length': 512, 35 | 'truncation_strategy_title': 'beginning', 36 | 'truncation_strategy_text': 'beginning' 37 | } 38 | } 39 | """Configuration for ``BERTClassifier``.""" 40 | 41 | 42 | BERT_CLASSIFIER_HYPER_PARAM_SPACE = [ 43 | skopt.space.Real( 44 | low=1e-8, 45 | high=1e-2, 46 | prior='log-uniform', 47 | name='lr'), 48 | skopt.space.Real( 49 | low=1e-5, 50 | high=1e0, 51 | prior='log-uniform', 52 | name='weight_decay'), 53 | skopt.space.Real( 54 | low=0.0, 55 | high=1.0, 56 | prior='uniform', 57 | name='warmup_proportion'), 58 | skopt.space.Integer( 59 | low=1, 60 | high=10, 61 | name='n_epochs'), 62 | skopt.space.Integer( 63 | low=3, 64 | high=10, 65 | name='log_train_batch_size') 66 | ] 67 | """The hyper-param search space for ``BERTClassifier``.""" 68 | 69 | 70 | BERT_CLASSIFIER_TRANSFORM = ( 71 | lambda 72 | pretrained_bert, 73 | max_sequence_length, 74 | truncation_strategy_title, 75 | truncation_strategy_text: 76 | Compose([ 77 | BertTransform( 78 | tokenizer=BertTokenizer.from_pretrained( 79 | pretrained_bert, 80 | do_lower_case=pretrained_bert.endswith('-uncased')), 81 | max_sequence_length=max_sequence_length, 82 | truncation_strategy=( 83 | truncation_strategy_title, 84 | truncation_strategy_text 85 | )), 86 | lambda d: { 87 | 'input_ids': torch.tensor(d['input_ids']), 88 | 'attention_mask': torch.tensor(d['input_mask']), 89 | 'token_type_ids': torch.tensor(d['segment_ids']) 90 | } 91 | ]) 92 | ) 93 | """The factory to create data transforms for ``BERTClassifier``.""" 94 | 95 | 96 | # the BERT ranking baseline 97 | 98 | BERTRanker = BertForMultipleChoice.from_pretrained 99 | """Rank choices with a softmax over a fine-tuned BERT model.""" 100 | 101 | 102 | BERT_RANKER_CONFIG = { 103 | 'model': { 104 | # N.B. 
pretrained_model_name_or_path for the model must be the 105 | # same as pretrained_bert for the transform 106 | 'pretrained_model_name_or_path': 'bert-large-uncased' 107 | }, 108 | 'transform': { 109 | # N.B. pretrained_bert for the transform must be the same as 110 | # pretrained_model_name_or_path for the model 111 | 'pretrained_bert': 'bert-large-uncased', 112 | 'max_sequence_length': 92 113 | } 114 | } 115 | """Configuration for ``BERTRanker``.""" 116 | 117 | 118 | BERT_RANKER_HYPER_PARAM_SPACE = [ 119 | skopt.space.Real( 120 | low=1e-8, 121 | high=1e-2, 122 | prior='log-uniform', 123 | name='lr'), 124 | skopt.space.Real( 125 | low=1e-5, 126 | high=1e0, 127 | prior='log-uniform', 128 | name='weight_decay'), 129 | skopt.space.Real( 130 | low=0.0, 131 | high=1.0, 132 | prior='uniform', 133 | name='warmup_proportion'), 134 | skopt.space.Integer( 135 | low=1, 136 | high=25, 137 | name='n_epochs'), 138 | skopt.space.Integer( 139 | low=3, 140 | high=10, 141 | name='log_train_batch_size') 142 | ] 143 | """The hyper-param seach space for ``BERTRanker``.""" 144 | 145 | 146 | BERT_RANKER_TRANSFORM = ( 147 | lambda 148 | pretrained_bert, 149 | max_sequence_length: 150 | Compose([ 151 | # wrap each action in a tuple for passing it to BertTransform 152 | lambda actions: tuple((action, None) for action in actions), 153 | # map BertTransform across all the action choices 154 | Map( 155 | transform=BertTransform( 156 | tokenizer=BertTokenizer.from_pretrained( 157 | pretrained_bert, 158 | do_lower_case=pretrained_bert.endswith('-uncased')), 159 | max_sequence_length=max_sequence_length, 160 | truncation_strategy=('beginning', 'beginning'))), 161 | # collect the action choices and stack their tensors so the 162 | # choices can be their own dimension of the batch 163 | lambda ds: { 164 | 'input_ids': torch.stack([ 165 | torch.tensor(d['input_ids']) 166 | for d in ds 167 | ], dim=0), 168 | 'attention_mask': torch.stack([ 169 | torch.tensor(d['input_mask']) 170 | for d in ds 171 | 
], dim=0), 172 | 'token_type_ids': torch.stack([ 173 | torch.tensor(d['segment_ids']) 174 | for d in ds 175 | ], dim=0) 176 | } 177 | ]) 178 | ) 179 | """The factory to create data transforms for ``BERTRanker``.""" 180 | -------------------------------------------------------------------------------- /src/scruples/baselines/corpus.py: -------------------------------------------------------------------------------- 1 | """Baseline configuration for the scruples corpus.""" 2 | 3 | from . import ( 4 | bert, 5 | labels, 6 | linear, 7 | naivebayes, 8 | roberta, 9 | style, 10 | trees) 11 | 12 | 13 | SHALLOW_BASELINES = { 14 | 'prior': ( 15 | labels.PriorBaseline, 16 | labels.PRIOR_HYPER_PARAMETERS 17 | ), 18 | 'stratified': ( 19 | labels.StratifiedBaseline, 20 | labels.STRATIFIED_HYPER_PARAMETERS 21 | ), 22 | 'logisticregression': ( 23 | linear.LogisticRegressionBaseline, 24 | linear.LOGISTIC_REGRESSION_HYPER_PARAMETERS 25 | ), 26 | 'bernoullinb': ( 27 | naivebayes.BernoulliNBBaseline, 28 | naivebayes.BERNOULLINB_HYPER_PARAMETERS 29 | ), 30 | 'multinomialnb': ( 31 | naivebayes.MultinomialNBBaseline, 32 | naivebayes.MULTINOMIALNB_HYPER_PARAMETERS 33 | ), 34 | 'complementnb': ( 35 | naivebayes.ComplementNBBaseline, 36 | naivebayes.COMPLEMENTNB_HYPER_PARAMETERS 37 | ), 38 | 'stylistic': ( 39 | style.StylisticXGBoostBaseline, 40 | style.STYLISTICXGBOOST_HYPER_PARAMETERS 41 | ), 42 | 'randomforest': ( 43 | trees.RandomForestBaseline, 44 | trees.RANDOM_FOREST_HYPER_PARAMETERS 45 | ) 46 | } 47 | """Shallow baseline models for the scruples corpus.""" 48 | 49 | 50 | FINE_TUNE_LM_BASELINES = { 51 | 'bert': ( 52 | bert.BERTClassifier, 53 | bert.BERT_CLASSIFIER_CONFIG, 54 | bert.BERT_CLASSIFIER_HYPER_PARAM_SPACE, 55 | bert.BERT_CLASSIFIER_TRANSFORM 56 | ), 57 | 'roberta': ( 58 | roberta.RoBERTaClassifier, 59 | roberta.ROBERTA_CLASSIFIER_CONFIG, 60 | roberta.ROBERTA_CLASSIFIER_HYPER_PARAM_SPACE, 61 | roberta.ROBERTA_CLASSIFIER_TRANSFORM 62 | ) 63 | } 64 | """Fine-tuned language model 
baselines for the scruples corpus.""" 65 | -------------------------------------------------------------------------------- /src/scruples/baselines/labels.py: -------------------------------------------------------------------------------- 1 | """Label only baselines.""" 2 | 3 | from sklearn.dummy import DummyClassifier 4 | from sklearn.pipeline import Pipeline 5 | 6 | 7 | # the class prior baseline 8 | 9 | PriorBaseline = Pipeline([ 10 | ('classifier', DummyClassifier(strategy='prior')) 11 | ]) 12 | """Predict using the label distribution.""" 13 | 14 | PRIOR_HYPER_PARAMETERS = {} 15 | """The hyper-param search space for ``PriorBaseline``.""" 16 | 17 | 18 | # the stratified sampling baseline 19 | 20 | StratifiedBaseline = Pipeline([ 21 | ('classifier', DummyClassifier(strategy='stratified')) 22 | ]) 23 | """Predict by sampling a class according to its probability.""" 24 | 25 | STRATIFIED_HYPER_PARAMETERS = {} 26 | """The hyper-param search space for ``StratifiedBaseline``.""" 27 | -------------------------------------------------------------------------------- /src/scruples/baselines/linear.py: -------------------------------------------------------------------------------- 1 | """Linear model baselines.""" 2 | 3 | from sklearn.feature_extraction.text import ( 4 | CountVectorizer, 5 | TfidfTransformer) 6 | from sklearn.linear_model import LogisticRegression 7 | from sklearn.pipeline import Pipeline 8 | from sklearn.preprocessing import FunctionTransformer, MaxAbsScaler 9 | 10 | from . 
import utils 11 | 12 | 13 | # the logistic regression baseline 14 | 15 | LogisticRegressionBaseline = Pipeline([ 16 | ( 17 | 'concatenator', 18 | FunctionTransformer(func=utils.concat_title_and_text, validate=False) 19 | ), 20 | ( 21 | 'vectorizer', 22 | CountVectorizer( 23 | input='content', 24 | encoding='utf-8', 25 | decode_error='strict', 26 | preprocessor=None, 27 | tokenizer=None, 28 | token_pattern=r'(?u)\b\w\w+\b', 29 | max_features=None, 30 | vocabulary=None) 31 | ), 32 | ( 33 | 'tfidf', 34 | TfidfTransformer(smooth_idf=True) 35 | ), 36 | ( 37 | 'scaler', 38 | MaxAbsScaler() 39 | ), 40 | ( 41 | 'classifier', 42 | LogisticRegression( 43 | penalty='l2', 44 | dual=False, 45 | tol=1e-4, 46 | intercept_scaling=1., 47 | solver='lbfgs', 48 | max_iter=100, 49 | warm_start=True) 50 | ) 51 | ]) 52 | """Predict using logistic regression on bag-of-ngrams features.""" 53 | 54 | LOGISTIC_REGRESSION_HYPER_PARAMETERS = { 55 | 'vectorizer__strip_accents': ['ascii', 'unicode', None], 56 | 'vectorizer__lowercase': [True, False], 57 | 'vectorizer__stop_words': ['english', None], 58 | 'vectorizer__ngram_range': [ 59 | (lo, hi) 60 | for lo in range(1, 2) 61 | for hi in range(lo, lo + 5) 62 | ], 63 | 'vectorizer__analyzer': ['word', 'char', 'char_wb'], 64 | 'vectorizer__max_df': (0.75, 1., 'uniform'), 65 | 'vectorizer__min_df': (0., 0.25, 'uniform'), 66 | 'vectorizer__binary': [True, False], 67 | 'tfidf__norm': ['l1', 'l2', None], 68 | 'tfidf__use_idf': [True, False], 69 | 'tfidf__sublinear_tf': [True, False], 70 | 'classifier__C': (1e-6, 1e2, 'log-uniform'), 71 | 'classifier__fit_intercept': [True, False], 72 | 'classifier__class_weight': ['balanced', None], 73 | 'classifier__multi_class': ['ovr', 'multinomial'] 74 | } 75 | """The hyper-param search space for ``LogisticRegressionBaseline``.""" 76 | 77 | 78 | # the logistic ranker baseline 79 | 80 | LogisticRankerBaseline = Pipeline([ 81 | ( 82 | 'featurizer', 83 | utils.ResourceTransformer( 84 | transformer=Pipeline([ 85 | ( 86 
| 'vectorizer', 87 | CountVectorizer( 88 | input='content', 89 | encoding='utf-8', 90 | decode_error='featurizer', 91 | preprocessor=None, 92 | tokenizer=None, 93 | token_pattern=r'(?u)\b\w\w+\b', 94 | max_features=None, 95 | vocabulary=None) 96 | ), 97 | ( 98 | 'tfidf', 99 | TfidfTransformer(smooth_idf=True) 100 | ), 101 | ( 102 | 'scaler', 103 | MaxAbsScaler() 104 | ) 105 | ])) 106 | ), 107 | ( 108 | 'classifier', 109 | LogisticRegression( 110 | penalty='l2', 111 | dual=False, 112 | tol=1e-4, 113 | fit_intercept=False, 114 | intercept_scaling=1., 115 | solver='lbfgs', 116 | max_iter=100, 117 | warm_start=True) 118 | ) 119 | ]) 120 | """Rank using logistic regression on bag-of-ngrams features.""" 121 | 122 | LOGISTIC_RANKER_HYPER_PARAMETERS = { 123 | 'featurizer__transformer__vectorizer__strip_accents': ['ascii', 'unicode', None], 124 | 'featurizer__transformer__vectorizer__lowercase': [True, False], 125 | 'featurizer__transformer__vectorizer__stop_words': ['english', None], 126 | 'featurizer__transformer__vectorizer__ngram_range': [ 127 | (lo, hi) 128 | for lo in range(1, 2) 129 | for hi in range(lo, lo + 5) 130 | ], 131 | 'featurizer__transformer__vectorizer__analyzer': ['word', 'char', 'char_wb'], 132 | 'featurizer__transformer__vectorizer__max_df': (0.90, 1., 'uniform'), 133 | 'featurizer__transformer__vectorizer__min_df': (0., 0.10, 'uniform'), 134 | 'featurizer__transformer__vectorizer__binary': [True, False], 135 | 'featurizer__transformer__tfidf__norm': ['l1', 'l2', None], 136 | 'featurizer__transformer__tfidf__use_idf': [True, False], 137 | 'featurizer__transformer__tfidf__sublinear_tf': [True, False], 138 | 'classifier__C': (1e-6, 1e2, 'log-uniform'), 139 | 'classifier__class_weight': ['balanced', None], 140 | } 141 | """The hyper-param search space for ``LogisticRankerBaseline``.""" 142 | -------------------------------------------------------------------------------- /src/scruples/baselines/loss.py: 
-------------------------------------------------------------------------------- 1 | """Custom losses for baselines.""" 2 | 3 | import torch as th 4 | from torch.nn import functional as F 5 | 6 | 7 | class SoftCrossEntropyLoss(th.nn.Module): 8 | """Cross-entropy with soft reference labels.""" 9 | 10 | # N.B. This loss can be used in two ways. First, labels can be 11 | # averaged and then passed in as the targets, and second, the raw 12 | # counts of each label can be passed in as the target. The first 13 | # approach leverages the soft labels, but ignores relative certainty 14 | # of one instances labeling versus another. The second approach 15 | # leverages all of the annotation information available. Both are 16 | # equivalent if each instance has the same number of labels. 17 | def forward(self, input, target): 18 | return - th.mean( 19 | th.sum(target * F.log_softmax(input, dim=-1), dim=-1)) 20 | 21 | 22 | class DirichletMultinomialLoss(th.nn.Module): 23 | """Negative log-likelihood for a dirichlet-multinomial.""" 24 | 25 | # N.B. note that this function computes the likelihood of the 26 | # observed labels, and not the likelihood of the sufficient 27 | # statistic derived from them (i.e., the counts of each label). We 28 | # only need the sufficient statistic however to compute this 29 | # likelihood, and both lead to the same MLE. 
30 | def forward(self, inputs, targets): 31 | inputs = th.exp(inputs) 32 | return - th.mean( 33 | th.lgamma(th.sum(inputs, dim=-1)) 34 | + th.sum(th.lgamma(inputs + targets), dim=-1) 35 | - th.lgamma(th.sum(inputs + targets, dim=-1)) 36 | - th.sum(th.lgamma(inputs), dim=-1)) 37 | -------------------------------------------------------------------------------- /src/scruples/baselines/metrics.py: -------------------------------------------------------------------------------- 1 | """Metrics for assessing baseline models.""" 2 | 3 | from sklearn import metrics 4 | 5 | 6 | METRICS = { 7 | 'accuracy': ( 8 | 'accuracy', 9 | metrics.accuracy_score, 10 | { 11 | 'greater_is_better': True, 12 | 'needs_proba': False 13 | } 14 | ), 15 | 'balanced_accuracy': ( 16 | 'balanced accuracy', 17 | metrics.balanced_accuracy_score, 18 | { 19 | 'greater_is_better': True, 20 | 'needs_proba': False 21 | } 22 | ), 23 | 'precision_micro': ( 24 | 'precision (micro)', 25 | lambda y_true, y_pred: metrics.precision_score( 26 | y_true=y_true, 27 | y_pred=y_pred, 28 | average='micro'), 29 | { 30 | 'greater_is_better': True, 31 | 'needs_proba': False 32 | } 33 | ), 34 | 'recall_micro': ( 35 | 'recall (micro)', 36 | lambda y_true, y_pred: metrics.recall_score( 37 | y_true=y_true, 38 | y_pred=y_pred, 39 | average='micro'), 40 | { 41 | 'greater_is_better': True, 42 | 'needs_proba': False 43 | } 44 | ), 45 | 'f1_micro': ( 46 | 'f1 (micro)', 47 | lambda y_true, y_pred: metrics.f1_score( 48 | y_true=y_true, 49 | y_pred=y_pred, 50 | average='micro'), 51 | { 52 | 'greater_is_better': True, 53 | 'needs_proba': False 54 | } 55 | ), 56 | 'precision_macro': ( 57 | 'precision (macro)', 58 | lambda y_true, y_pred: metrics.precision_score( 59 | y_true=y_true, 60 | y_pred=y_pred, 61 | average='macro'), 62 | { 63 | 'greater_is_better': True, 64 | 'needs_proba': False 65 | } 66 | ), 67 | 'recall_macro': ( 68 | 'recall (macro)', 69 | lambda y_true, y_pred: metrics.recall_score( 70 | y_true=y_true, 71 | y_pred=y_pred, 
72 | average='macro'), 73 | { 74 | 'greater_is_better': True, 75 | 'needs_proba': False 76 | } 77 | ), 78 | 'f1_macro': ( 79 | 'f1 (macro)', 80 | lambda y_true, y_pred: metrics.f1_score( 81 | y_true=y_true, 82 | y_pred=y_pred, 83 | average='macro'), 84 | { 85 | 'greater_is_better': True, 86 | 'needs_proba': False 87 | } 88 | ), 89 | 'f1_weighted': ( 90 | 'f1 (weighted)', 91 | lambda y_true, y_pred: metrics.f1_score( 92 | y_true=y_true, 93 | y_pred=y_pred, 94 | average='weighted'), 95 | { 96 | 'greater_is_better': True, 97 | 'needs_proba': False 98 | } 99 | ), 100 | 'log_loss': ( 101 | 'log loss', 102 | lambda y_true, y_pred: metrics.log_loss( 103 | y_true=y_true, 104 | y_pred=y_pred, 105 | eps=1e-9), 106 | { 107 | 'greater_is_better': False, 108 | 'needs_proba': True 109 | } 110 | ), 111 | 'matthews_corrcoef': ( 112 | 'matthews correlation coefficient', 113 | metrics.matthews_corrcoef, 114 | { 115 | 'greater_is_better': True, 116 | 'needs_proba': False 117 | } 118 | ) 119 | # N.B., do not include a key for "xentropy" or "calibrated_xentropy" 120 | # in this dictionary. Those keys are reserved for the cross-entropy 121 | # between the predicted probabilities and the dataset's label 122 | # scores, which is computed in the scruples.scripts.analyze.*.predictions 123 | # scripts. 124 | } 125 | """A dictionary defining the important metrics to assess baselines. 126 | 127 | The dictionary maps metric names to ``(name, metric, scorer_kwargs)`` 128 | tuples. ``name`` is the name of the metric, while ``metric`` is a 129 | function for computing it, and ``scorer_kwargs`` is a dictionary 130 | containing two keys: ``"greater_is_better"``, a boolean defining whether 131 | or not higher values are better, and ``"needs_proba"``, a boolean 132 | defining whether to pass the predicted labels or the predicted 133 | probabilities. 
134 | """ 135 | -------------------------------------------------------------------------------- /src/scruples/baselines/naivebayes.py: -------------------------------------------------------------------------------- 1 | """Naive bayes baselines.""" 2 | 3 | from sklearn.feature_extraction.text import CountVectorizer 4 | from sklearn.naive_bayes import ( 5 | BernoulliNB, 6 | MultinomialNB, 7 | ComplementNB) 8 | from sklearn.pipeline import Pipeline 9 | from sklearn.preprocessing import FunctionTransformer 10 | 11 | from . import utils 12 | 13 | 14 | # the bernoulli naive bayes baseline 15 | 16 | BernoulliNBBaseline = Pipeline([ 17 | ( 18 | 'concatenator', 19 | FunctionTransformer(func=utils.concat_title_and_text, validate=False) 20 | ), 21 | ( 22 | 'vectorizer', 23 | CountVectorizer( 24 | input='content', 25 | encoding='utf-8', 26 | decode_error='strict', 27 | preprocessor=None, 28 | tokenizer=None, 29 | token_pattern=r'(?u)\b\w\w+\b', 30 | max_features=None, 31 | vocabulary=None, 32 | binary=True) 33 | ), 34 | ( 35 | 'classifier', 36 | BernoulliNB(binarize=None, fit_prior=True) 37 | ) 38 | ]) 39 | """Predict using bernoulli naive bayes on bag-of-ngrams features.""" 40 | 41 | BERNOULLINB_HYPER_PARAMETERS = { 42 | 'vectorizer__strip_accents': ['ascii', 'unicode', None], 43 | 'vectorizer__lowercase': [True, False], 44 | 'vectorizer__stop_words': ['english', None], 45 | 'vectorizer__ngram_range': [ 46 | (lo, hi) 47 | for lo in range(1, 2) 48 | for hi in range(lo, lo + 5) 49 | ], 50 | 'vectorizer__analyzer': ['word', 'char', 'char_wb'], 51 | 'vectorizer__max_df': (0.75, 1., 'uniform'), 52 | 'vectorizer__min_df': (0., 0.25, 'uniform'), 53 | 'classifier__alpha': (0., 5., 'uniform') 54 | } 55 | """The hyper-param search space for ``BeroulliNBBaseline``.""" 56 | 57 | 58 | # the multinomial naive bayes baseline 59 | 60 | MultinomialNBBaseline = Pipeline([ 61 | ( 62 | 'concatenator', 63 | FunctionTransformer(func=utils.concat_title_and_text, validate=False) 64 | ), 65 | ( 66 
| 'vectorizer', 67 | CountVectorizer( 68 | input='content', 69 | encoding='utf-8', 70 | decode_error='strict', 71 | preprocessor=None, 72 | tokenizer=None, 73 | token_pattern=r'(?u)\b\w\w+\b', 74 | max_features=None, 75 | vocabulary=None, 76 | binary=False) 77 | ), 78 | ( 79 | 'classifier', 80 | MultinomialNB(fit_prior=True) 81 | ) 82 | ]) 83 | """Predict using multinomial naive bayes on bag-of-ngrams features.""" 84 | 85 | MULTINOMIALNB_HYPER_PARAMETERS = { 86 | 'vectorizer__strip_accents': ['ascii', 'unicode', None], 87 | 'vectorizer__lowercase': [True, False], 88 | 'vectorizer__stop_words': ['english', None], 89 | 'vectorizer__ngram_range': [ 90 | (lo, hi) 91 | for lo in range(1, 2) 92 | for hi in range(lo, lo + 5) 93 | ], 94 | 'vectorizer__analyzer': ['word', 'char', 'char_wb'], 95 | 'vectorizer__max_df': (0.75, 1., 'uniform'), 96 | 'vectorizer__min_df': (0., 0.25, 'uniform'), 97 | 'classifier__alpha': (0., 5., 'uniform') 98 | } 99 | """The hyper-param search space for ``MultinomialNBBaseline``.""" 100 | 101 | 102 | # the complement naive bayes baseline 103 | 104 | ComplementNBBaseline = Pipeline([ 105 | ( 106 | 'concatenator', 107 | FunctionTransformer(func=utils.concat_title_and_text, validate=False) 108 | ), 109 | ( 110 | 'vectorizer', 111 | CountVectorizer( 112 | input='content', 113 | encoding='utf-8', 114 | decode_error='strict', 115 | preprocessor=None, 116 | tokenizer=None, 117 | token_pattern=r'(?u)\b\w\w+\b', 118 | max_features=None, 119 | vocabulary=None, 120 | binary=False) 121 | ), 122 | ( 123 | 'classifier', 124 | ComplementNB(fit_prior=True) 125 | ) 126 | ]) 127 | """Predict using complement naive bayes on bag-of-ngrams features.""" 128 | 129 | COMPLEMENTNB_HYPER_PARAMETERS = { 130 | 'vectorizer__strip_accents': ['ascii', 'unicode', None], 131 | 'vectorizer__lowercase': [True, False], 132 | 'vectorizer__stop_words': ['english', None], 133 | 'vectorizer__ngram_range': [ 134 | (lo, hi) 135 | for lo in range(1, 2) 136 | for hi in range(lo, lo + 5) 
137 | ], 138 | 'vectorizer__analyzer': ['word', 'char', 'char_wb'], 139 | 'vectorizer__max_df': (0.75, 1., 'uniform'), 140 | 'vectorizer__min_df': (0., 0.25, 'uniform'), 141 | 'classifier__alpha': (0., 5., 'uniform'), 142 | 'classifier__norm': [True, False] 143 | } 144 | """The hyper-param search space for ``ComplementNBBaseline``.""" 145 | -------------------------------------------------------------------------------- /src/scruples/baselines/resource.py: -------------------------------------------------------------------------------- 1 | """Baseline configuration for the scruples resource.""" 2 | 3 | from . import ( 4 | bert, 5 | labels, 6 | linear, 7 | roberta, 8 | style) 9 | 10 | 11 | SHALLOW_BASELINES = { 12 | 'prior': ( 13 | labels.PriorBaseline, 14 | labels.PRIOR_HYPER_PARAMETERS 15 | ), 16 | 'stratified': ( 17 | labels.StratifiedBaseline, 18 | labels.STRATIFIED_HYPER_PARAMETERS 19 | ), 20 | 'fewestwords': ( 21 | style.FewestWordsBaseline, 22 | style.FEWEST_WORDS_HYPER_PARAMETERS 23 | ), 24 | 'mostwords': ( 25 | style.MostWordsBaseline, 26 | style.MOST_WORDS_HYPER_PARAMETERS 27 | ), 28 | 'fewestcharacters': ( 29 | style.FewestCharactersBaseline, 30 | style.FEWEST_CHARACTERS_HYPER_PARAMETERS 31 | ), 32 | 'mostcharacters': ( 33 | style.MostCharactersBaseline, 34 | style.MOST_CHARACTERS_HYPER_PARAMETERS 35 | ), 36 | 'logisticranker': ( 37 | linear.LogisticRankerBaseline, 38 | linear.LOGISTIC_RANKER_HYPER_PARAMETERS 39 | ), 40 | 'stylistic': ( 41 | style.StyleRankerBaseline, 42 | style.STYLE_RANKER_HYPER_PARAMETERS 43 | ) 44 | } 45 | """Shallow baseline models for the scruples resource.""" 46 | 47 | 48 | FINE_TUNE_LM_BASELINES = { 49 | 'bert': ( 50 | bert.BERTRanker, 51 | bert.BERT_RANKER_CONFIG, 52 | bert.BERT_RANKER_HYPER_PARAM_SPACE, 53 | bert.BERT_RANKER_TRANSFORM 54 | ), 55 | 'roberta': ( 56 | roberta.RoBERTaRanker, 57 | roberta.ROBERTA_RANKER_CONFIG, 58 | roberta.ROBERTA_RANKER_HYPER_PARAM_SPACE, 59 | roberta.ROBERTA_RANKER_TRANSFORM 60 | ) 61 | } 62 | 
"""RoBERTa baselines."""

from transformers import (
    RobertaForMultipleChoice,
    RobertaForSequenceClassification,
    RobertaTokenizer)
import skopt
import torch

from ..data.labels import Label
from ..dataset.transforms import (
    BertTransform,
    Compose,
    Map)


# the RoBERTa sequence classification baseline

# alias the pretrained-model factory so the training code can treat all
# baselines uniformly as (factory, config, space, transform) tuples
RoBERTaClassifier = RobertaForSequenceClassification.from_pretrained
"""Predict fixed classes with a fine-tuned RoBERTa model."""


ROBERTA_CLASSIFIER_CONFIG = {
    'model': {
        # N.B. pretrained_model_name_or_path for the model must be the
        # same as pretrained_roberta for the transform
        'pretrained_model_name_or_path': 'roberta-large',
        'num_labels': len(Label)
    },
    'transform': {
        # N.B. pretrained_roberta for the transform must be the same as
        # pretrained_model_name_or_path for the model
        'pretrained_roberta': 'roberta-large',
        'max_sequence_length': 512,
        'truncation_strategy_title': 'beginning',
        'truncation_strategy_text': 'beginning'
    }
}
"""Configuration for ``RoBERTaClassifier``."""


# NOTE(review): log_train_batch_size appears to be the log (presumably
# base 2) of the training batch size -- confirm against the training
# script that consumes these dimensions.
ROBERTA_CLASSIFIER_HYPER_PARAM_SPACE = [
    skopt.space.Real(
        low=1e-8,
        high=1e-2,
        prior='log-uniform',
        name='lr'),
    skopt.space.Real(
        low=1e-5,
        high=1e0,
        prior='log-uniform',
        name='weight_decay'),
    skopt.space.Real(
        low=0.0,
        high=1.0,
        prior='uniform',
        name='warmup_proportion'),
    skopt.space.Integer(
        low=1,
        high=10,
        name='n_epochs'),
    skopt.space.Integer(
        low=3,
        high=10,
        name='log_train_batch_size')
]
"""The hyper-param search space for ``RoBERTaClassifier``."""


# a factory taking the 'transform' config entries and returning a
# callable that tokenizes (title, text) pairs and converts the results
# to the tensors RoBERTa expects
ROBERTA_CLASSIFIER_TRANSFORM = (
    lambda
        pretrained_roberta,
        max_sequence_length,
        truncation_strategy_title,
        truncation_strategy_text:
    Compose([
        BertTransform(
            tokenizer=RobertaTokenizer.from_pretrained(
                pretrained_roberta,
                do_lower_case=False),
            max_sequence_length=max_sequence_length,
            truncation_strategy=(
                truncation_strategy_title,
                truncation_strategy_text),
            starting_sep_token=True
        ),
        # rename input_mask -> attention_mask and tensorize for the
        # transformers API
        lambda d: {
            'input_ids': torch.tensor(d['input_ids']),
            'attention_mask': torch.tensor(d['input_mask'])
        }
    ])
)
"""The factory to create data transforms for ``RoBERTaClassifier``."""


# the RoBERTa ranking baseline

RoBERTaRanker = RobertaForMultipleChoice.from_pretrained
"""Rank choices with a softmax over a fine-tuned RoBERTa model."""
ROBERTA_RANKER_CONFIG = {
    'model': {
        # N.B. pretrained_model_name_or_path for the model must be the
        # same as pretrained_roberta for the transform
        'pretrained_model_name_or_path': 'roberta-large'
    },
    'transform': {
        # N.B. pretrained_roberta for the transform must be the same as
        # pretrained_model_name_or_path for the model
        'pretrained_roberta': 'roberta-large',
        'max_sequence_length': 90
    }
}
"""Configuration for ``RoBERTaRanker``."""


ROBERTA_RANKER_HYPER_PARAM_SPACE = [
    skopt.space.Real(
        low=1e-8,
        high=1e-2,
        prior='log-uniform',
        name='lr'),
    skopt.space.Real(
        low=1e-5,
        high=1e0,
        prior='log-uniform',
        name='weight_decay'),
    skopt.space.Real(
        low=0.0,
        high=1.0,
        prior='uniform',
        name='warmup_proportion'),
    skopt.space.Integer(
        low=1,
        high=25,
        name='n_epochs'),
    skopt.space.Integer(
        low=3,
        high=10,
        name='log_train_batch_size')
]
"""The hyper-param search space for ``RoBERTaRanker``."""


# a factory taking the 'transform' config entries and returning a
# callable that tokenizes each action choice and stacks the results
ROBERTA_RANKER_TRANSFORM = (
    lambda
        pretrained_roberta,
        max_sequence_length:
    Compose([
        # wrap each action in a tuple for passing it to BertTransform
        lambda actions: tuple((action, None) for action in actions),
        # map BertTransform across all the action choices
        Map(
            transform=BertTransform(
                tokenizer=RobertaTokenizer.from_pretrained(
                    pretrained_roberta,
                    do_lower_case=False),
                max_sequence_length=max_sequence_length,
                truncation_strategy=('beginning', 'beginning'),
                starting_sep_token=True
            )
        ),
        # collect the action choices and stack their tensors so the
        # choices can be their own dimension of the batch
        lambda ds: {
            'input_ids': torch.stack([
                torch.tensor(d['input_ids'])
                for d in ds
            ], dim=0),
            'attention_mask': torch.stack([
                torch.tensor(d['input_mask'])
                for d in ds
            ], dim=0)
        }
    ])
)
"""The factory to create data transforms for ``RoBERTaRanker``."""
# the random forest baseline

RandomForestBaseline = Pipeline([
    (
        'concatenator',
        # join each instance's title and text into one string
        FunctionTransformer(func=utils.concat_title_and_text, validate=False)
    ),
    (
        'vectorizer',
        CountVectorizer(
            input='content',
            encoding='utf-8',
            decode_error='strict',
            preprocessor=None,
            tokenizer=None,
            token_pattern=r'(?u)\b\w\w+\b',
            max_features=None,
            vocabulary=None)
    ),
    (
        'tfidf',
        TfidfTransformer(smooth_idf=True)
    ),
    (
        'classifier',
        RandomForestClassifier(
            n_estimators=100,
            max_depth=None,
            # 'sqrt' is exactly what the old 'auto' setting resolved to
            # for classifiers; 'auto' was deprecated in scikit-learn 1.1
            # and removed in 1.3, so spell the behavior out explicitly.
            max_features='sqrt',
            max_leaf_nodes=None,
            min_impurity_decrease=0.,
            oob_score=False,
            n_jobs=1,
            verbose=0,
            warm_start=False)
    )
])
"""Predict using a random forest on bag-of-ngrams features."""
def concat_title_and_text(features: pd.DataFrame) -> np.ndarray:
    """Return the concatenation of the title and text features.

    Parameters
    ----------
    features : pd.DataFrame
        The features for the scruples dataset.

    Returns
    -------
    np.ndarray
        The concatenation of the title and text strings, separated by
        a newline character, in a numpy array.
    """
    titles = features['title']
    texts = features['text']

    # join the two columns element-wise with a newline in between
    return (titles + '\n' + texts).to_numpy()
def dirichlet_multinomial(log_alphas: np.ndarray) -> np.ndarray:
    """Return class probabilities from a dirichlet-multinomial model.

    Parameters
    ----------
    log_alphas : np.ndarray
        An n x k dimensional numpy array where n is the number of
        samples and k is the number of classes. Each value should be
        the log of the alpha parameter of the predicted dirichlet
        distribution for the corresponding instance.

    Returns
    -------
    np.ndarray
        An n x k dimensional array giving the class probabilities
        corresponding to ``log_alphas`` for each sample.
    """
    alphas = np.exp(log_alphas)

    # normalizing the alphas along the class axis yields the expected
    # class probabilities under the dirichlet-multinomial
    return alphas / np.sum(alphas, axis=-1, keepdims=True)
class ResourceTransformer(BaseEstimator, TransformerMixin):
    """Featurize the action pairs from the scruples resource.

    ``ResourceTransformer`` runs one featurization pipeline
    (``self.transformer``) over both actions of each instance from the
    scruples resource, and returns the difference between the two
    actions' features.

    Parameters for ``self.transformer`` may be set through this class
    by prefixing their names with ``transformer__``.

    ``ResourceTransformer`` is particularly useful in front of linear
    models like logistic regression, since applying such a model to the
    difference of the features is the same as taking the difference of
    the final scores.

    Attributes
    ----------
    See `Parameters`_.

    Parameters
    ----------
    transformer : Transformer
        The transformer to apply to the actions.
    """
    def __init__(
            self,
            transformer: TransformerMixin
    ) -> None:
        self.transformer = transformer

    def set_params(
            self,
            **params: Dict[str, Any]
    ) -> 'ResourceTransformer':
        """Set the parameters of this estimator.

        Parameters prefixed with ``transformer__`` are forwarded (with
        the prefix stripped) to ``self.transformer``; all other
        parameters are set on this instance.

        Parameters
        ----------
        **params : Dict[str, Any]
            The parameters to set.

        Returns
        -------
        self : object
            The instance.
        """
        prefix = 'transformer__'
        own_params = {
            name: value
            for name, value in params.items()
            if not name.startswith(prefix)
        }
        nested_params = {
            name[len(prefix):]: value
            for name, value in params.items()
            if name.startswith(prefix)
        }
        # set the parameters on this instance
        super().set_params(**own_params)
        # set the parameters on the wrapped transformer
        self.transformer.set_params(**nested_params)

        return self

    def fit(
            self,
            X: pd.DataFrame,
            y: np.ndarray = None
    ) -> 'ResourceTransformer':
        """Fit the instance to ``X``.

        Fitting an instance of ``ResourceTransformer`` fits its
        ``self.transformer`` attribute to the data. The ``y`` argument
        is ignored.

        Parameters
        ----------
        X : pd.DataFrame
            The data to fit.
        y : None
            An ignored argument.

        Returns
        -------
        self : object
            The instance.
        """
        # fit the wrapped transformer on the text of both actions
        both_actions = pd.concat([X['action0'], X['action1']])
        self.transformer.fit(both_actions)

        self._fitted = True

        return self

    def transform(
            self,
            X: pd.DataFrame
    ) -> Any:
        """Transform ``X``.

        Parameters
        ----------
        X : pd.DataFrame
            The data to transform.

        Returns
        -------
        Any
            The difference of the features for the actions derived by
            applying ``self.transformer`` to them.
        """
        check_is_fitted(self, '_fitted')

        # featurize action1 first to mirror the subtraction order
        action1_features = self.transformer.transform(X['action1'])
        action0_features = self.transformer.transform(X['action0'])

        return action1_features - action0_features
@attr.s(frozen=True, kw_only=True)
class Comment:
    """A class representing a comment.

    Attributes
    ----------
    label : Optional[Label]
        The label expressed by the comment. If no label can be extracted
        or multiple labels are extracted from the comment, then this
        attribute is ``None``.
    is_top_level : bool
        ``True`` if the comment is a top-level comment (i.e., a direct
        response to a link and not another comment).
    has_empty_body : bool
        ``True`` if the body text of the comment is empty.
    is_deleted : bool
        ``True`` if the comment is deleted.
    is_by_automoderator : bool
        ``True`` if the comment is by the AutoModerator.
    is_spam : bool
        ``True`` if the comment is classified as spam (i.e., has
        irrelevant or meaningless content) by a set of rules-based
        filters.
    is_good : bool
        ``True`` if the comment is a good candidate for contributing a
        label. Some comments have good content but are not candidates
        for inclusion in the dataset due to certain dataset design
        decisions, such as requiring comments to be top-level.

    See `Parameters`_ for additional attributes.

    Parameters
    ----------
    id : str
        A unique ID for the comment.
    subreddit_id : str
        The ID of the subreddit the comment was posted in. The ID has
        ``"t5_"`` prepended to it to represent the fact that it is a
        *subreddit* ID.
    subreddit : str
        The name of the subreddit the comment was posted in.
    link_id : str
        The ID of the post that the comment was made on. The ID has
        ``"t3_"`` prepended to it to represent the fact that it is a
        *post* ID.
    parent_id : str
        The ID of the parent object (either a comment or a post). If the
        parent object is a post, the ID will begin with ``"t3_"``. If
        the parent object is a comment, the ID will begin with
        ``"t1_"``.
    created_utc : int
        The time that the comment was created in seconds since the
        epoch.
    author : str
        The username of the comment's author.
    body : str
        The body text of the comment.
    score : int
        The score (upvotes minus downvotes) of the comment.
    controversiality : int
        The controversiality score for the comment.
    gilded : int
        The number of times the comment has been gilded.
    """
    # identifying information
    id: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)

    # reddit location
    subreddit_id: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)
    subreddit: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)
    link_id: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)
    parent_id: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)

    # creation
    created_utc: int = attr.ib(
        validator=attr.validators.instance_of(int),
        converter=int)
    author: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)

    # content
    body: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)

    # user interactions
    score: int = attr.ib(
        validator=attr.validators.instance_of(int),
        converter=int)
    controversiality: int = attr.ib(
        validator=attr.validators.instance_of(int),
        converter=int)
    gilded: int = attr.ib(
        validator=attr.validators.instance_of(int),
        converter=int)

    # computed content properties

    @utils.cached_property
    def label(self) -> Optional[Label]:
        # Label.extract_from_text handles the no-label and
        # multiple-label cases by returning None
        return Label.extract_from_text(self.body)

    # computed properties for identifying comments to count in label
    # scores

    @utils.cached_property
    def is_top_level(self) -> bool:
        # a top-level comment's parent is the post (link) itself
        return self.parent_id == self.link_id

    @utils.cached_property
    def has_empty_body(self) -> bool:
        return self.body == ""

    @utils.cached_property
    def is_deleted(self) -> bool:
        # deleted / removed comments have their body replaced by these
        # placeholder strings
        return (
            self.body == '[deleted]'
            or self.body == '[removed]'
        )

    @utils.cached_property
    def is_by_automoderator(self) -> bool:
        return self.author == settings.AUTO_MODERATOR_NAME

    @utils.cached_property
    def is_spam(self) -> bool:
        # N.B. place cheaper predicates earlier so short-circuiting can
        # avoid evaluating more expensive predicates.
        return (
            self.has_empty_body
            or self.is_deleted
            or self.is_by_automoderator
        )

    @utils.cached_property
    def is_good(self) -> bool:
        # N.B. place cheaper predicates earlier so short-circuiting can
        # avoid evaluating more expensive predicates.
        return (
            self.is_top_level
            and not self.is_spam
        )
@attr.s(frozen=True, kw_only=True)
class LabelScores:
    """A class representing scores for all the labels.

    Attributes
    ----------
    binarized_label_to_score : Dict[BinarizedLabel, int]
        A dictionary mapping each binarized label to its corresponding
        score.
    best_binarized_label : BinarizedLabel
        The overall highest scoring binarized label. Ties are broken
        arbitrarily.
    best_label : Label
        The overall highest scoring label. Ties are broken arbitrarily.
    has_all_zero_binarized_label_scores : bool
        ``True`` if all the binarized label scores are zero.
    has_all_zero_label_scores : bool
        ``True`` if all the label scores are zero.
    is_good : bool
        ``True`` if the label scores are considered good for inclusion
        in the final dataset.

    See `Parameters`_ for additional attributes.

    Parameters
    ----------
    label_to_score : Dict[Label, int]
        A dictionary mapping each label to its corresponding score.
    """
    label_to_score: Dict[Label, int] = attr.ib(
        validator=attr.validators.deep_mapping(
            key_validator=attr.validators.instance_of(Label),
            value_validator=attr.validators.instance_of(int)))

    # computed content properties

    @utils.cached_property
    def binarized_label_to_score(self) -> Dict[BinarizedLabel, int]:
        # initialize every binarized label to zero so all keys are
        # present even when no label maps onto them
        binarized_label_to_score = {
            binarized_label: 0
            for binarized_label in BinarizedLabel
        }
        for label, score in self.label_to_score.items():
            binarized_label = BinarizedLabel.binarize(label)
            # labels that don't binarize (binarize returns None) don't
            # contribute to the binarized scores
            if binarized_label is not None:
                binarized_label_to_score[binarized_label] += score

        return binarized_label_to_score

    @utils.cached_property
    def best_binarized_label(self) -> Optional[BinarizedLabel]:
        return max(
            self.binarized_label_to_score.items(),
            key=lambda t: t[1]
        )[0]

    @utils.cached_property
    def best_label(self) -> Label:
        return max(
            self.label_to_score.items(),
            key=lambda t: t[1]
        )[0]

    # computed properties for identifying good label scores

    @utils.cached_property
    def has_all_zero_binarized_label_scores(self) -> bool:
        return all(
            v == 0
            for v in self.binarized_label_to_score.values()
        )

    @utils.cached_property
    def has_all_zero_label_scores(self) -> bool:
        return all(v == 0 for v in self.label_to_score.values())

    @utils.cached_property
    def is_good(self) -> bool:
        # N.B. place cheaper predicates earlier so short-circuiting can
        # avoid evaluating more expensive predicates.
        return (
            not self.has_all_zero_binarized_label_scores
            and not self.has_all_zero_label_scores
        )
@enum.unique
class PostType(enum.Enum):
    """A post type.

    Posts are categorized into three types:

    1. **HISTORICAL** : The author is asking if they're in the wrong, based
       on an event that has actually happened.
    2. **HYPOTHETICAL** : The author is asking if they would be in the wrong,
       if they were to perform some action.
    3. **META** : The post is discussing the subreddit itself.

    Attributes
    ----------
    See `Parameters`_.

    Parameters
    ----------
    index : int
        A (unique) numerical index assigned to the post type.
    reddit_name : str
        The name for the post type used on the subreddit from which the
        data originates.
    patterns : List[str]
        A list of strings, each representing a regular expression
        pattern used to extract that post type from a post's title. Note
        that patterns are compiled to regular expressions when they're
        bound to the ``PostType`` instance as an attribute.
    """
    # N.B. the patterns use features of the third-party ``regex``
    # module: \m and \M match the start / end of a word, and {e<=2}
    # allows up to two errors (fuzzy matching) when matching the
    # spelled-out phrase.
    HISTORICAL = (0, 'AITA', [
        r'\m(?i:AITAH?)\M',
        r'(?i:Am I the asshole){e<=2}'
    ])
    HYPOTHETICAL = (1, 'WIBTA', [
        r'\m(?i:WIBTAH?)\M',
        r'(?i:Would I be the asshole){e<=2}'
    ])
    META = (2, 'META', [
        r'\mMETA\M',
        r'\[(?i:META)\]'
    ])

    @classmethod
    def extract_from_title(
            cls,
            title: str
    ) -> Optional['PostType']:
        """Return a post type extracted from ``title`` or ``None``.

        If a post type can be unambiguously extracted from the title,
        return it; otherwise, return ``None``.

        Parameters
        ----------
        title : str
            The title string from which to extract the post type.

        Returns
        -------
        Optional[PostType]
            The extracted post type.
        """
        found_post_types = set()
        for post_type in cls:
            if post_type.in_(title):
                found_post_types.add(post_type)

        # only report a type when exactly one matched; several matches
        # make the title ambiguous
        return found_post_types.pop() if len(found_post_types) == 1 else None

    def __init__(
            self,
            index: int,
            reddit_name: str,
            patterns: List[str]
    ) -> None:
        self.index = index
        self.reddit_name = reddit_name
        # compile the raw pattern strings once, when the enum member is
        # created
        self.patterns = [regex.compile(pattern) for pattern in patterns]

    def in_(
            self,
            title: str
    ) -> bool:
        """Return ``True`` if ``title`` expresses the post type.

        Return ``True`` if the post type has any pattern that matches a
        substring from ``title``.

        Parameters
        ----------
        title : str
            The title string to check for the post type.

        Returns
        -------
        bool
            Whether or not the post type is in ``title``.
        """
        return any(pattern.search(title) for pattern in self.patterns)
def cached_property(method: Callable):
    """Decorate a method to act as a cached property.

    This decorator converts a method into a cached property. It is
    intended only for methods of classes decorated with ``@attr.s``
    where ``frozen=True``. It behaves like ``@property`` except that
    the computed value is cached on the instance.

    Parameters
    ----------
    method : Callable, required
        The method to decorate. ``method`` should take only one
        argument: ``self``.

    Returns
    -------
    Callable
        The decorated method.

    Notes
    -----
    On a frozen attrs class, values may safely be cached because the
    object is intended to be immutable. The cache lives on the object
    itself (under a leading-underscore attribute), so it is garbage
    collected together with the object.
    """
    # the attribute name under which the computed value is cached
    cache_attr = f'_{method.__name__}'

    @wraps(method)
    def getter(self):
        if not hasattr(self, cache_attr):
            # the instance may be frozen, so bypass any restrictive
            # __setattr__ by going through object.__setattr__ directly
            object.__setattr__(self, cache_attr, method(self))

        return getattr(self, cache_attr)

    return property(getter)
class PredictionDataset(torch.utils.data.Dataset):
    """A PyTorch ``Dataset`` class for prediction.

    Serves each feature as-is, or passes it through ``transform``
    first when one is provided.

    Parameters
    ----------
    features : List[Any], required
        The list of features for the dataset's instances.
    transform : Optional[Callable], optional (default=None)
        The transformation to apply to the features. If ``None``, no
        transformation is applied.
    """
    def __init__(
            self,
            features: List[Any],
            transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()

        self.features = features
        self.transform = transform

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, key: int) -> Any:
        # the transform is applied lazily, at retrieval time
        item = self.features[key]

        return self.transform(item) if self.transform else item
class TooFewCharactersFilter(base.LoggedCallable):
    """Filter strings with too few characters.

    Parameters
    ----------
    min_chars : int, required
        The minimum number of characters to allow.
    """

    def __init__(
            self,
            min_chars: int,
            *args,
            **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)

        self.min_chars = min_chars

    def apply(
            self,
            s: str
    ) -> bool:
        """Return ``True`` if ``s`` has too few characters.

        Parameters
        ----------
        s : str, required
            The string to test.

        Returns
        -------
        bool
            ``True`` if ``s`` has fewer than ``self.min_chars``
            characters, otherwise ``False``.
        """
        # reject non-string inputs loudly instead of coercing them
        if not isinstance(s, str):
            raise ValueError(f's ({s}) must be a string.')

        too_few = len(s) < self.min_chars

        return too_few
class PrefixFilter(base.LoggedCallable):
    """Filter out strings with a certain prefix.

    Parameters
    ----------
    prefix : str, required
        The prefix to look for when filtering strings.
    case_sensitive : bool, optional (default=False)
        Whether or not to match the prefix in a case-sensitive fashion.
    """

    def __init__(
            self,
            prefix: str,
            case_sensitive: bool = False,
            *args,
            **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)

        self.prefix = prefix
        self.case_sensitive = case_sensitive

    def apply(
            self,
            s: str
    ) -> bool:
        """Return ``True`` if ``s`` begins with ``prefix``.

        Parameters
        ----------
        s : str, required
            The string to test.

        Returns
        -------
        bool
            ``True`` if ``s`` begins with ``self.prefix``, otherwise
            ``False``.
        """
        # reject non-string inputs loudly instead of coercing them
        if not isinstance(s, str):
            raise ValueError(f's ({s}) must be a string.')

        target = self.prefix
        candidate = s
        if not self.case_sensitive:
            # fold case on both sides for a case-insensitive match
            target = target.lower()
            candidate = candidate.lower()

        return candidate.startswith(target)
class WhWordFilter(base.LoggedCallable):
    """Filter strings which start with a wh-word."""

    # the wh-words to look for at the start of a string
    _wh_words = [
        'why',
        'who',
        'which',
        'what',
        'where',
        'when',
        'how'
    ]

    # one compiled, case-insensitive pattern per wh-word, allowing an
    # optional opening quote before and closing quote after the word
    _wh_word_patterns = [
        regex.compile(r'(?i:[\p{{Pi}}"\']?\m{}\M[\p{{Pf}}"\']?)'.format(word))
        for word in _wh_words
    ]

    def apply(self, s: str) -> bool:
        """Return ``True`` if ``s`` starts with a wh-word.

        Parameters
        ----------
        s : str, required
            The string to test.

        Returns
        -------
        bool
            ``True`` if ``s`` begins with a wh-word, otherwise
            ``False``.
        """
        for pattern in self._wh_word_patterns:
            if pattern.match(s):
                return True
        return False
# Calling conjugate once up front works around an initialization bug in
# pattern.en where the very first call can raise; subsequent calls then
# succeed. See https://github.com/clips/pattern/issues/243.
try: conjugate('give')
except: pass


class GerundifyingTransformer(base.LoggedCallable):
    """Transform the input into a gerund phrase.

    The first sentence's root verb is re-conjugated into its progressive
    (-ing) form, and the root's subject and auxiliary verbs are dropped;
    any later sentences are passed through unchanged.
    """

    # spaCy pipeline used to parse inputs; NER is disabled because only
    # the part-of-speech tags and dependency parse are used below
    _nlp = spacy.load('en', disable=['ner'])

    @classmethod
    def _is_root(
            cls,
            token: spacy.tokens.token.Token
    ) -> bool:
        """Return ``True`` if ``token`` is the root verb of its parse."""
        return token.dep_ == 'ROOT' and token.pos_ == 'VERB'

    @classmethod
    def _is_rootlike(
            cls,
            token: spacy.tokens.token.Token
    ) -> bool:
        """Return ``True`` if ``token`` is the root verb or is conjoined
        with the root verb.
        """
        return (
            cls._is_root(token)
            or any(cls._is_root(c) for c in token.conjuncts)
        )

    @classmethod
    def _conjugate(
            cls,
            text: str
    ) -> str:
        """Return ``text`` conjugated into its progressive form."""
        # handle some special cases
        if text in ["'m", 'm']:
            # expand contracted forms of "am" so the conjugator can
            # recognize the verb
            text = 'am'
        elif text == 'left':
            # "left" is ambiguous (past tense of "leave" vs. other
            # readings); force the verb reading
            text = 'leave'
        else:
            pass

        # PRESENT + PROGRESSIVE yields the -ing (present participle) form
        return conjugate(
            verb=text,
            tense=PRESENT,
            person=None,
            number=None,
            mood=INDICATIVE,
            aspect=PROGRESSIVE)

    def apply(
            self,
            x: str
    ) -> str:
        """Return ``x`` rewritten as a gerund phrase.

        Parameters
        ----------
        x : str, required
            The text to transform.

        Returns
        -------
        str
            The transformed text.
        """
        token_strs = []
        for token in self._nlp(x):
            if token.sent.start != 0:
                # skip conjugating verbs that aren't in the first
                # sentence
                token_strs.append(token.text_with_ws)
                continue

            if self._is_rootlike(token):
                # conjugate the token if it is like the root verb,
                # preserving any trailing whitespace from the original
                token_str = self._conjugate(token.text)\
                    + token.text_with_ws[len(token.text):]
            elif token.dep_ == 'nsubj' and self._is_rootlike(token.head):
                # remove the subject attached to the root
                continue
            elif token.dep_ == 'aux' and self._is_rootlike(token.head):
                # remove auxiliary verbs attached to the root
                continue
            elif token.text in ["n't", 'nt'] and self._is_rootlike(token.head):
                # fix forms of "not" coming from contractions
                token_str = 'not' + token.text_with_ws[len(token.text):]
            else:
                # nothing to do here
                token_str = token.text_with_ws

            if (
                token_str.startswith('not')
                and len(token_strs) > 0
                and token_strs[-1] in ['being', 'doing', 'having']
            ):
                # swap "not" with the preceding verb so the negation
                # precedes the gerund (e.g. "not being" rather than
                # "being not").
                # NOTE(review): the membership test is an exact string
                # comparison, so it only fires when the preceding entry
                # carries no trailing whitespace -- confirm this is the
                # intended condition.
                token_str, token_strs[-1] = token_strs[-1], token_str

            token_strs.append(token_str)

        return ''.join(token_strs)
"""Scripts for analyzing the scruples corpus."""

import logging

import click

from . import (
    extractions,
    human_performance,
    oracle_performance,
    predictions,
    statistics)


logger = logging.getLogger(__name__)


# main function

@click.group()
def corpus():
    """Analyze the scruples corpus."""
    pass


# register subcommands to the command group

# each submodule exposes its click command under the module's own name
subcommands = [
    getattr(module, module.__name__.rsplit('.', 1)[-1])
    for module in (
        extractions,
        human_performance,
        oracle_performance,
        predictions,
        statistics,
    )
]

for subcommand in subcommands:
    corpus.add_command(subcommand)
@click.command()
@click.argument(
    'comments_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'posts_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'split_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'output_path',
    type=click.Path(exists=False, file_okay=True, dir_okay=False))
def human_performance(
        comments_path: str,
        posts_path: str,
        split_path: str,
        output_path: str
) -> None:
    """Estimate human performance on the scruples corpus.

    Read in the comments from COMMENTS_PATH, posts from POSTS_PATH, and
    the split from SPLIT_PATH, then estimate human performance metrics
    and write them to OUTPUT_PATH.

    Human performance is estimated by holding out one comment from each
    post, and using its label as the human prediction. If the post has
    only a single comment, then the majority label is used as the
    prediction.
    """
    logger.info('Reading in comments.')

    link_id_to_comments = collections.defaultdict(list)
    with click.open_file(comments_path, 'r') as comments_file:
        # iterate the file lazily instead of calling .readlines(), which
        # would materialize the whole (potentially very large) file in
        # memory before processing begins
        for ln in tqdm.tqdm(comments_file, **settings.TQDM_KWARGS):
            comment = instantiate_attrs_with_extra_kwargs(
                Comment,
                **json.loads(ln))

            # IDs are usually prefixed with something like "t1_",
            # "t2_", etc. to denote what type of object it is. Slice
            # off the first 3 characters to remove this prefix from
            # the link id because it will not be on the posts' IDs
            # when we join the comments to them.
            link_id_to_comments[comment.link_id[3:]].append(comment)

    logger.info('Reading in posts.')

    split_post_ids = set()
    with click.open_file(split_path, 'r') as split_file:
        for ln in split_file:
            split_post_ids.add(json.loads(ln)['post_id'])

    posts = []
    with click.open_file(posts_path, 'r') as posts_file:
        # lazy iteration here as well, for the same reason as above
        for ln in tqdm.tqdm(posts_file, **settings.TQDM_KWARGS):
            kwargs = json.loads(ln)
            post = instantiate_attrs_with_extra_kwargs(
                Post,
                comments=link_id_to_comments[kwargs['id']],
                **kwargs)

            # only keep posts belonging to the requested split
            if post.id in split_post_ids:
                posts.append(post)

    logger.info('Computing human performance.')

    human_preds = []
    gold_labels = []
    for post in tqdm.tqdm(posts, **settings.TQDM_KWARGS):
        # collect the label index from every usable comment on the post
        post_labels = [
            comment.label.index
            for comment in post.comments
            if comment.is_good and comment.label is not None
        ]
        # shuffle so the held-out comment is chosen at random
        random.shuffle(post_labels)

        if len(post_labels) > 1:
            # hold out one comment's label as the human prediction, and
            # use the majority vote of the remaining comments as gold
            human_preds.append(post_labels[-1])
            gold_labels.append(collections.Counter(
                post_labels[:-1]).most_common(1)[0][0])
        elif len(post_labels) == 1:
            # predict with the majority label
            # NOTE(review): Label.OTHER is hard-coded here as the
            # corpus-wide majority class -- confirm this still holds if
            # the underlying data changes.
            human_preds.append(Label.OTHER.index)
            gold_labels.append(post_labels[0])
        else:
            raise ValueError('Found a post without a label.')

    # score the human predictions on every hard-label metric
    with open(output_path, 'w') as metrics_file:
        json.dump({
            key: metric(
                y_true=gold_labels,
                y_pred=human_preds)
            for key, (_, metric, scorer_kwargs) in METRICS.items()
            if not scorer_kwargs['needs_proba']
        }, metrics_file)
"""Estimate the oracle performance on the scruples corpus."""

import json
import logging

import click
import numpy as np
from scipy import stats
import tqdm

from .... import utils, settings
from ....data.labels import Label
from ....baselines.metrics import METRICS


logger = logging.getLogger(__name__)


# constants

REPORT_TEMPLATE =\
"""Scruples Corpus Oracle Performance
==================================
Oracle performance on the scruples corpus.


Oracle Metrics
--------------
The following metrics provide estimates of the oracle performance, i.e. the
performance in each metric attained when you use the _true_ underlying label
distribution for that instance to predict the labels.

We obtain these estimates using an empirical Bayesian methodology: first, we
fit a dirichlet-multinomial model to the data, where each instance's label
distribution is drawn from a dirichlet prior. Then, for each instance we
condition on the observed labels and compute the expected performance of
predicting with the true distribution over the posterior.

{metrics_report}
"""

N_SAMPLES = 10000
"""The number of posterior samples drawn for each estimate."""


# helper functions

def _sample_true_scores(params, label_scores):
    """Return one posterior sample of each instance's label distribution.

    For each instance, draw from the dirichlet posterior obtained by
    updating the prior parameters ``params`` with that instance's
    observed label counts (``label_scores``).
    """
    return np.array([
        stats.dirichlet.rvs([
            a + x
            for a, x in zip(params, scores)
        ], size=1)[0]
        for scores in label_scores
    ])


# main function

@click.command()
@click.argument(
    'dataset_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'output_path',
    type=click.Path(exists=False, file_okay=True, dir_okay=False))
def oracle_performance(
        dataset_path: str,
        output_path: str
) -> None:
    """Estimate oracle performance and write a report.

    Read in the dataset from DATASET_PATH, estimate the oracle
    performance and write the results to OUTPUT_PATH.
    """
    label_name_to_idx = {
        label.name: label.index
        for label in Label
    }
    # Step 1: Read in the dataset.
    with click.open_file(dataset_path, 'r') as dataset_file:
        labels = []
        label_scores = []
        for ln in dataset_file:
            row = json.loads(ln)

            labels.append(label_name_to_idx[row['label']])

            # densify the {label name: score} mapping into a vector
            # indexed by label
            scores = [0 for _ in Label]
            for label_name, score in row['label_scores'].items():
                scores[label_name_to_idx[label_name]] = score
            label_scores.append(scores)
    labels = np.array(labels)
    label_scores = np.array(label_scores)

    # Step 2: Estimate the dirichlet-multinomial parameters.
    params = utils.estimate_dirichlet_multinomial_parameters(
        label_scores)

    # Step 3: Estimate the expected performance.
    metric_name_to_value = {}
    for name, metric, scorer_kwargs in METRICS.values():
        logger.info(f'Computing estimate for {name}.')
        # estimate performance on the usual metrics
        value_samples = []
        for _ in tqdm.tqdm(range(N_SAMPLES), **settings.TQDM_KWARGS):
            true_scores = _sample_true_scores(params, label_scores)
            value = metric(
                y_true=labels,
                y_pred=true_scores
                if scorer_kwargs['needs_proba']
                else np.argmax(true_scores, axis=1))
            value_samples.append(value)
        metric_name_to_value[name] = np.mean(value_samples)
    if 'xentropy' in metric_name_to_value:
        # xentropy is computed separately below; a metric of the same
        # name would be silently overwritten
        raise ValueError(
            'METRICS should not have a key named "xentropy". This issue'
            ' is a bug in the library, please notify the maintainers.')

    # estimate performance on the xentropy, which requires soft
    # ground-truth labels
    logger.info('Computing estimate for xentropy.')
    value_samples = []
    for _ in tqdm.tqdm(range(N_SAMPLES), **settings.TQDM_KWARGS):
        true_scores = _sample_true_scores(params, label_scores)
        value = utils.xentropy(
            y_true=label_scores / np.sum(label_scores, axis=1).reshape(-1, 1),
            y_pred=true_scores)
        value_samples.append(value)
    metric_name_to_value['xentropy'] = np.mean(value_samples)

    # Step 4: Write the report.
    # pad metric names so the values line up in a column
    metric_name_width = 1 + max(
        len(name)
        for name in metric_name_to_value.keys())
    metrics_report = '\n'.join(
        f'{name: <{metric_name_width}}: {value:.4f}'
        for name, value in metric_name_to_value.items())
    with click.open_file(output_path, 'w') as output_file:
        output_file.write(
            REPORT_TEMPLATE.format(
                metrics_report=metrics_report))
def _sample_label_counts(ns, pss):
    """Sample observed label counts for each instance.

    ``ns`` gives the number of annotations per instance and ``pss`` each
    instance's true label distribution; the counts are drawn from the
    corresponding multinomial.
    """
    return np.array([
        stats.multinomial.rvs(n, ps)
        for n, ps in zip(ns, pss)
    ])


@click.command()
@click.argument(
    'corpus_dir',
    type=click.Path(exists=True, file_okay=False, dir_okay=True))
@click.argument(
    'output_path',
    type=click.Path(exists=False, file_okay=True, dir_okay=False))
def oracle_estimator(
        corpus_dir: str,
        output_path: str
) -> None:
    """Conduct simulation experiments for the oracle estimator.

    Read the corpus from CORPUS_DIR, conduct different simulation experiments
    to evaluate the robustness and effectiveness of the oracle performance
    estimator, and write the results to OUTPUT_PATH.
    """
    # Read in the corpus.

    logger.info(f'Reading data from {corpus_dir}.')

    _, _, _, label_scores = ScruplesCorpus(data_dir=corpus_dir).dev
    label_scores = label_scores.values

    n_total = len(label_scores)
    alphas = utils.estimate_dirichlet_multinomial_parameters(label_scores)

    # Create the scenarios to simulate. Each scenario fixes the number
    # of annotations per instance (ns), the true label distributions
    # (pss), and the simulated observed annotations (ys).

    logger.info('Creating simulation scenarios.')

    scenarios = []
    # scenario 1: running the estimator on the corpus
    ns = np.sum(label_scores, axis=-1)
    pss = stats.dirichlet.rvs(alpha=alphas, size=n_total)
    scenarios.append({
        'name': 'corpus',
        'ns': ns,
        'pss': pss,
        'ys': _sample_label_counts(ns, pss)
    })
    # scenario 2: running the estimator with 3 annotations per example
    ns = 3 * np.ones(n_total)
    pss = stats.dirichlet.rvs(alpha=alphas, size=n_total)
    scenarios.append({
        'name': '3 annotations',
        'ns': ns,
        'pss': pss,
        'ys': _sample_label_counts(ns, pss)
    })
    # scenario 3: running the estimator with a non-dirichlet prior (a
    # mixture of three dirichlets, shuffled so components are mixed in)
    ns = np.sum(label_scores, axis=-1)
    pss = np.concatenate([
        stats.dirichlet.rvs(alpha=[2, 1, 1, 1, 1], size=n_total//3),
        stats.dirichlet.rvs(alpha=[1, 1, 3, 1, 1], size=n_total//3),
        stats.dirichlet.rvs(alpha=[1, 1, 1, 1, 2], size=n_total - 2 * n_total//3)
    ])
    np.random.shuffle(pss)
    scenarios.append({
        'name': 'non-dirichlet prior',
        'ns': ns,
        'pss': pss,
        'ys': _sample_label_counts(ns, pss)
    })

    # Run the simulations.

    logger.info('Running simulations.')

    results = collections.defaultdict(dict)
    for scenario in tqdm.tqdm(scenarios, **settings.TQDM_KWARGS):
        name = scenario['name']
        ns = scenario['ns']
        pss = scenario['pss']
        ys = scenario['ys']

        # true oracle performance: predict with the true distributions
        # and score against the (normalized) simulated annotations
        results[name]['oracle'] = {
            metric_name: metric(
                y_pred=make_predictions(pss),
                y_true=make_predictions(
                    ys / np.expand_dims(np.sum(ys, axis=-1), axis=-1)))
            for metric_name, metric, make_predictions in METRICS
        }

        # estimate oracle performance from the observed annotations only
        # (the previously computed-but-unused alpha re-estimate was dead
        # code and has been removed; utils.oracle_performance handles
        # its own estimation internally)
        results[name]['estimate'] = {
            metric_name: utils.oracle_performance(
                ys=ys,
                metric=metric,
                make_predictions=make_predictions,
                n_samples=10000)[0]
            for metric_name, metric, make_predictions in tqdm.tqdm(
                METRICS, **settings.TQDM_KWARGS)
        }

    # Write the results to disk.

    logger.info(f'Writing results to {output_path}.')

    with click.open_file(output_path, 'w') as output_file:
        json.dump(results, output_file)
@click.command()
@click.argument(
    'split_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'output_path',
    type=click.Path(exists=False, file_okay=True, dir_okay=False))
def human_performance(
        split_path: str,
        output_path: str
) -> None:
    """Estimate human performance on the scruples resource.

    Read in the split from SPLIT_PATH, then estimate human performance
    metrics and write them to OUTPUT_PATH.

    Human performance is computed by comparing the majority vote label
    of the human performance annotators to the majority vote label of
    the gold annotators.
    """
    logger.info('Computing human performance.')

    # parse every row of the split up front, then pull out the human
    # (majority-vote) predictions and the gold labels
    with click.open_file(split_path, 'r') as split_file:
        rows = [json.loads(ln) for ln in split_file]
    human_preds = [row['human_perf_label'] for row in rows]
    gold_labels = [row['gold_label'] for row in rows]

    # score the predictions with every hard-label metric
    with open(output_path, 'w') as metrics_file:
        json.dump({
            key: metric(
                y_true=gold_labels,
                y_pred=human_preds)
            for key, (_, metric, scorer_kwargs) in METRICS.items()
            if not scorer_kwargs['needs_proba']
        }, metrics_file)
N_SAMPLES = 10000
"""The number of posterior samples drawn for each estimate."""


# helper functions

def _sample_true_scores(params, label_scores):
    """Return one posterior sample of each instance's label distribution.

    For each instance, draw from the dirichlet posterior obtained by
    updating the prior parameters ``params`` with that instance's
    observed label counts (``label_scores``).
    """
    return np.array([
        stats.dirichlet.rvs([
            a + x
            for a, x in zip(params, scores)
        ], size=1)[0]
        for scores in label_scores
    ])


# main function

@click.command()
@click.argument(
    'dataset_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'output_path',
    type=click.Path(exists=False, file_okay=True, dir_okay=False))
def oracle_performance(
        dataset_path: str,
        output_path: str
) -> None:
    """Estimate oracle performance and write a report.

    Read in the dataset from DATASET_PATH, estimate the oracle
    performance and write the results to OUTPUT_PATH.
    """
    # Step 1: Read in the dataset.
    with click.open_file(dataset_path, 'r') as dataset_file:
        labels = []
        label_scores = []
        label_annotations_all = []
        for ln in dataset_file:
            row = json.loads(ln)

            labels.append(row['gold_label'])
            label_scores.append(row['gold_annotations'])
            # pool the gold and human-performance annotations so the
            # prior is fit on all available annotations
            label_annotations_all.append([
                i + j
                for i, j in zip(
                    row['gold_annotations'],
                    row['human_perf_annotations'])
            ])

    labels = np.array(labels)
    label_scores = np.array(label_scores)
    label_annotations_all = np.array(label_annotations_all)

    # Step 2: Estimate the dirichlet-multinomial parameters.
    params = utils.estimate_dirichlet_multinomial_parameters(
        label_annotations_all)

    # Step 3: Estimate the expected performance.
    metric_name_to_value = {}
    for name, metric, scorer_kwargs in METRICS.values():
        logger.info(f'Computing estimate for {name}.')
        # estimate performance on the usual metrics
        value_samples = []
        for _ in tqdm.tqdm(range(N_SAMPLES), **settings.TQDM_KWARGS):
            true_scores = _sample_true_scores(params, label_scores)
            value = metric(
                y_true=labels,
                y_pred=true_scores
                if scorer_kwargs['needs_proba']
                else np.argmax(true_scores, axis=1))
            value_samples.append(value)
        metric_name_to_value[name] = np.mean(value_samples)
    if 'xentropy' in metric_name_to_value:
        # xentropy is computed separately below; a metric of the same
        # name would be silently overwritten
        raise ValueError(
            'METRICS should not have a key named "xentropy". This issue'
            ' is a bug in the library, please notify the maintainers.')

    # estimate performance on the xentropy, which requires soft
    # ground-truth labels
    logger.info('Computing estimate for xentropy.')
    value_samples = []
    for _ in tqdm.tqdm(range(N_SAMPLES), **settings.TQDM_KWARGS):
        true_scores = _sample_true_scores(params, label_scores)
        value = utils.xentropy(
            y_true=label_scores / np.sum(label_scores, axis=1).reshape(-1, 1),
            y_pred=true_scores)
        value_samples.append(value)
    metric_name_to_value['xentropy'] = np.mean(value_samples)

    # Step 4: Write the report.
    # pad metric names so the values line up in a column
    metric_name_width = 1 + max(
        len(name)
        for name in metric_name_to_value.keys())
    metrics_report = '\n'.join(
        f'{name: <{metric_name_width}}: {value:.4f}'
        for name, value in metric_name_to_value.items())
    with click.open_file(output_path, 'w') as output_file:
        output_file.write(
            REPORT_TEMPLATE.format(
                metrics_report=metrics_report))
43 | """ 44 | logger.info('Fitting the topic model to the action descriptions.') 45 | 46 | model = Pipeline([ 47 | ( 48 | 'vectorizer', 49 | CountVectorizer( 50 | input='content', 51 | encoding='utf-8', 52 | decode_error='strict', 53 | strip_accents=None, 54 | lowercase=True, 55 | preprocessor=None, 56 | tokenizer=None, 57 | stop_words='english', 58 | token_pattern=r'(?u)\b\w\w+\b', 59 | ngram_range=(1, 1), 60 | analyzer='word', 61 | max_df=1.0, 62 | min_df=2, 63 | max_features=None, 64 | vocabulary=None, 65 | binary=False, 66 | ) 67 | ), 68 | ( 69 | 'topic_model', 70 | LatentDirichletAllocation(n_components=N_COMPONENTS) 71 | ) 72 | ]) 73 | 74 | with click.open_file(split_path, 'r') as split_file: 75 | descriptions = [ 76 | action['description'] 77 | for ln in split_file 78 | for action in json.loads(ln)['actions'] 79 | ] 80 | 81 | model.fit(descriptions) 82 | 83 | feature_names = model.named_steps['vectorizer'].get_feature_names() 84 | components = model.named_steps['topic_model'].components_ 85 | 86 | topics_ = [ 87 | [ 88 | { 89 | 'word': word, 90 | 'alpha': alpha 91 | } 92 | for alpha, word in sorted( 93 | zip(component, feature_names) 94 | )[-25:][::-1] 95 | ] 96 | for component in components 97 | ] 98 | 99 | logger.info("Writing out the topics' top words.") 100 | 101 | with click.open_file(output_path, 'w') as output_file: 102 | for topic in topics_: 103 | output_file.write(json.dumps(topic) + '\n') 104 | 105 | logger.info("Finished topic analysis.") 106 | -------------------------------------------------------------------------------- /src/scruples/scripts/demo/__init__.py: -------------------------------------------------------------------------------- 1 | """Scripts for demos.""" 2 | 3 | import logging 4 | 5 | import click 6 | 7 | from . 
import ( 8 | norms, 9 | scoracle) 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | # main function 16 | 17 | @click.group() 18 | def demo(): 19 | """Run a demo's server.""" 20 | pass 21 | 22 | 23 | # register subcommands to the command group 24 | 25 | subcommands = [ 26 | norms.norms, 27 | scoracle.scoracle, 28 | ] 29 | 30 | for subcommand in subcommands: 31 | demo.add_command(subcommand) 32 | -------------------------------------------------------------------------------- /src/scruples/scripts/demo/norms.py: -------------------------------------------------------------------------------- 1 | """Run the norms demo's server.""" 2 | 3 | import logging 4 | 5 | import click 6 | from gevent.pywsgi import WSGIServer 7 | 8 | from ... import settings 9 | from ...demos.norms.app import ( 10 | app, 11 | get_device, 12 | load_model) 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | @click.command() 19 | @click.option( 20 | '--expose', is_flag=True, 21 | help='Whether to expose the server to the internet, or run on localhost.') 22 | @click.option( 23 | '--port', type=int, default=5050, 24 | help='The port on which to serve the demo. Defaults to 5050.') 25 | def norms( 26 | expose: bool, 27 | port: int 28 | ) -> None: 29 | """Serve the norms demo. 30 | 31 | In order to run this server, you must set the following environment 32 | variables: 33 | 34 | \b 35 | SCRUPLES_NORMS_ACTIONS_BASELINE : The baseline to use for the resource 36 | (actions). 37 | SCRUPLES_NORMS_ACTIONS_MODEL : The path to the saved pretrained 38 | model to use for predicting the actions. 39 | SCRUPLES_NORMS_CORPUS_BASELINE : The baseline to use for the corpus. 40 | SCRUPLES_NORMS_CORPUS_MODEL : The path to the saved pretrained 41 | model to use for predicting the corpus. 42 | SCRUPLES_NORMS_PREDICT_BATCH_SIZE : The batch size to use for 43 | prediction. 44 | SCRUPLES_NORMS_GPU_IDS : A comma separated list of GPU IDs to 45 | use. 
If none are provided, then the CPU will be used instead. 46 | 47 | """ 48 | # load the device and model 49 | 50 | get_device() 51 | 52 | logger.info('Loading the Actions model.') 53 | load_model(dataset='resource',) 54 | 55 | logger.info('Loading the Corpus model.') 56 | load_model(dataset='corpus') 57 | 58 | # start the server 59 | 60 | ip = '0.0.0.0' if expose else '127.0.0.1' 61 | 62 | logger.info(f'Running norms server on http://{ip}:{port}/') 63 | 64 | WSGIServer((ip, port), app).serve_forever() 65 | -------------------------------------------------------------------------------- /src/scruples/scripts/demo/scoracle.py: -------------------------------------------------------------------------------- 1 | """Run the scoracle demo's server.""" 2 | 3 | import logging 4 | 5 | import click 6 | from gevent.pywsgi import WSGIServer 7 | 8 | from scruples.demos.scoracle.app import app 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @click.command() 15 | @click.option( 16 | '--expose', is_flag=True, 17 | help='Whether to expose the server to the internet, or run on localhost.') 18 | @click.option( 19 | '--port', type=int, default=5000, 20 | help='The port on which to serve the demo. Defaults to 5000.') 21 | def scoracle( 22 | expose: bool, 23 | port: int 24 | ) -> None: 25 | """Serve the scoracle demo.""" 26 | ip = '0.0.0.0' if expose else '127.0.0.1' 27 | 28 | logger.info(f'Running scoracle server on http://{ip}:{port}/') 29 | 30 | WSGIServer((ip, port), app).serve_forever() 31 | -------------------------------------------------------------------------------- /src/scruples/scripts/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | """Scripts for evaluating models on scruples.""" 2 | 3 | import logging 4 | 5 | import click 6 | 7 | from . 
import ( 8 | resource, 9 | corpus) 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | # main function 16 | 17 | @click.group() 18 | def evaluate(): 19 | """Evaluate models on scruples.""" 20 | pass 21 | 22 | 23 | # register subcommands to the command group 24 | 25 | subcommands = [ 26 | resource.resource, 27 | corpus.corpus 28 | ] 29 | 30 | for subcommand in subcommands: 31 | evaluate.add_command(subcommand) 32 | -------------------------------------------------------------------------------- /src/scruples/scripts/evaluate/corpus/__init__.py: -------------------------------------------------------------------------------- 1 | """Scripts for evaluating models on the scruples corpus.""" 2 | 3 | import logging 4 | 5 | import click 6 | 7 | from . import ( 8 | predict_lm, 9 | run_shallow, 10 | tune_lm) 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | # main function 17 | 18 | @click.group() 19 | def corpus(): 20 | """Evaluate baseline models on the scruples corpus.""" 21 | pass 22 | 23 | 24 | # register subcommands to the command group 25 | 26 | subcommands = [ 27 | predict_lm.predict_lm, 28 | run_shallow.run_shallow, 29 | tune_lm.tune_lm 30 | ] 31 | 32 | for subcommand in subcommands: 33 | corpus.add_command(subcommand) 34 | -------------------------------------------------------------------------------- /src/scruples/scripts/evaluate/corpus/run_shallow.py: -------------------------------------------------------------------------------- 1 | """Run shallow baseline models on the scruples corpus.""" 2 | 3 | import collections 4 | import json 5 | import logging 6 | import os 7 | from typing import List 8 | 9 | import click 10 | import dill 11 | from sklearn.metrics import make_scorer 12 | from skopt import BayesSearchCV 13 | import tqdm 14 | 15 | from .... 
import settings, baselines 16 | from ....baselines.metrics import METRICS 17 | from ....dataset.readers import ScruplesCorpus 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | # main function 24 | 25 | @click.command() 26 | @click.argument( 27 | 'data_dir', 28 | type=click.Path(exists=True, file_okay=False, dir_okay=True)) 29 | @click.argument( 30 | 'results_dir', 31 | type=click.Path(exists=False, file_okay=False, dir_okay=True)) 32 | @click.argument( 33 | 'splits', type=click.Choice(ScruplesCorpus.SPLITS), nargs=-1) 34 | @click.option( 35 | '--metric', 36 | type=click.Choice(METRICS.keys()), 37 | default='log_loss', 38 | help='The metric to use for hyper-parameter tuning. Defaults to' 39 | ' log_loss.') 40 | @click.option( 41 | '--n-iter', type=int, default=128, 42 | help='The number of iterations of Bayesian optimization to run when' 43 | ' tuning baseline hyper-parameters. Defaults to 128.') 44 | @click.option( 45 | '--n-points', type=int, default=8, 46 | help='The number of points to evaluate in parallel during Bayesian' 47 | ' optimization. Defaults to 8.') 48 | @click.option( 49 | '--n-folds', type=int, default=4, 50 | help='The number of cross-validation folds to use. Defaults to 4.') 51 | @click.option( 52 | '--n-jobs', type=int, default=0, 53 | help='The number of parallel processes to use for tuning' 54 | ' hyper-parameters. At most n_folds * n_points processses can' 55 | ' be used at a given time. If 0, then the same number of' 56 | ' processes as CPUs will be used. Defaults to 0.') 57 | def run_shallow( 58 | data_dir: str, 59 | results_dir: str, 60 | splits: List[str], 61 | metric: str, 62 | n_iter: int, 63 | n_points: int, 64 | n_folds: int, 65 | n_jobs: int 66 | ) -> None: 67 | """Evaluate shallow baselines on the scruples corpus. 68 | 69 | Train shallow baseline models on the scruples corpus, reading the 70 | dataset from DATA_DIR, and writing trained models, logs, and other 71 | results to RESULTS_DIR. 
Performance is reported for each split 72 | provided as an argument. 73 | """ 74 | # Step 1: Manage and construct paths. 75 | 76 | logger.info('Creating the results directory.') 77 | 78 | os.makedirs(results_dir) 79 | model_paths = {} 80 | metrics_paths = collections.defaultdict(dict) 81 | predictions_paths = collections.defaultdict(dict) 82 | for baseline in baselines.corpus.SHALLOW_BASELINES.keys(): 83 | os.makedirs(os.path.join(results_dir, baseline)) 84 | model_paths[baseline] = os.path.join( 85 | results_dir, baseline, 'model.pkl') 86 | for split in splits: 87 | os.makedirs(os.path.join(results_dir, baseline, split)) 88 | metrics_paths[baseline][split] = os.path.join( 89 | results_dir, baseline, split, 'metrics.json') 90 | predictions_paths[baseline][split] = os.path.join( 91 | results_dir, baseline, split, 'predictions.jsonl') 92 | 93 | # Step 2: Load the data. 94 | 95 | logger.info(f'Loading data from {data_dir}.') 96 | 97 | dataset = ScruplesCorpus(data_dir=data_dir) 98 | 99 | # Step 3: Run the baselines. 100 | 101 | logger.info('Running the baselines.') 102 | 103 | for baseline, (Model, hyper_parameter_space) in tqdm.tqdm( 104 | baselines.corpus.SHALLOW_BASELINES.items(), 105 | **settings.TQDM_KWARGS 106 | ): 107 | # tune the hyper-parameters and train the model 108 | ids, features, labels, label_scores = dataset.train 109 | if hyper_parameter_space: 110 | model = BayesSearchCV( 111 | Model, 112 | hyper_parameter_space, 113 | scoring=make_scorer( 114 | score_func=METRICS[metric][1], 115 | **METRICS[metric][2]), 116 | n_iter=n_iter, 117 | n_points=n_points, 118 | cv=n_folds, 119 | n_jobs=os.cpu_count() if n_jobs == 0 else n_jobs, 120 | refit=True) 121 | else: 122 | model = Model 123 | model.fit(features, labels) 124 | 125 | # Step 4: Save the model. 126 | 127 | with open(model_paths[baseline], 'wb') as model_file: 128 | dill.dump(model, model_file) 129 | 130 | # Step 5: Run evaluation on the splits. 
131 | 132 | for split in splits: 133 | ids, features, labels, label_scores = getattr(dataset, split) 134 | 135 | predictions = model.predict(features) 136 | probabilities = model.predict_proba(features) 137 | 138 | with open(metrics_paths[baseline][split], 'w') as metrics_file: 139 | json.dump( 140 | { 141 | key: metric( 142 | y_true=labels, 143 | y_pred=probabilities 144 | if scorer_kwargs['needs_proba'] 145 | else predictions) 146 | for key, (_, metric, scorer_kwargs) in METRICS.items() 147 | }, 148 | metrics_file) 149 | 150 | with open(predictions_paths[baseline][split], 'w')\ 151 | as predictions_file: 152 | for id_, probs, prediction in zip( 153 | ids, probabilities, predictions 154 | ): 155 | predictions_file.write( 156 | json.dumps({ 157 | 'id': id_, 158 | 'label': prediction, 159 | 'label_scores': { 160 | class_: prob.tolist() 161 | for class_, prob 162 | in zip(model.classes_, probs) 163 | } 164 | }) + '\n') 165 | -------------------------------------------------------------------------------- /src/scruples/scripts/evaluate/resource/__init__.py: -------------------------------------------------------------------------------- 1 | """Scripts for evaluating models on the scruples resource.""" 2 | 3 | import logging 4 | 5 | import click 6 | 7 | from . 
import ( 8 | predict_lm, 9 | run_shallow, 10 | tune_lm) 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | # main function 17 | 18 | @click.group() 19 | def resource(): 20 | """Evaluate baseline models on the scruples resource.""" 21 | pass 22 | 23 | 24 | # register subcommands to the command group 25 | 26 | subcommands = [ 27 | predict_lm.predict_lm, 28 | run_shallow.run_shallow, 29 | tune_lm.tune_lm 30 | ] 31 | 32 | for subcommand in subcommands: 33 | resource.add_command(subcommand) 34 | -------------------------------------------------------------------------------- /src/scruples/scripts/evaluate/resource/run_shallow.py: -------------------------------------------------------------------------------- 1 | """Run shallow baseline models on the scruples resource.""" 2 | 3 | import collections 4 | import json 5 | import logging 6 | import os 7 | from typing import List 8 | 9 | import click 10 | import dill 11 | from sklearn.metrics import make_scorer 12 | from skopt import BayesSearchCV 13 | import tqdm 14 | 15 | from .... import settings, baselines 16 | from ....baselines.metrics import METRICS 17 | from ....dataset.readers import ScruplesResource 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | # main function 24 | 25 | @click.command() 26 | @click.argument( 27 | 'data_dir', 28 | type=click.Path(exists=True, file_okay=False, dir_okay=True)) 29 | @click.argument( 30 | 'results_dir', 31 | type=click.Path(exists=False, file_okay=False, dir_okay=True)) 32 | @click.argument( 33 | 'splits', type=click.Choice(ScruplesResource.SPLITS), nargs=-1) 34 | @click.option( 35 | '--metric', 36 | type=click.Choice(METRICS.keys()), 37 | default='log_loss', 38 | help='The metric to use for hyper-parameter tuning. Defaults to' 39 | ' log_loss.') 40 | @click.option( 41 | '--n-iter', type=int, default=128, 42 | help='The number of iterations of Bayesian optimization to run when' 43 | ' tuning baseline hyper-parameters. 
Defaults to 128.') 44 | @click.option( 45 | '--n-points', type=int, default=8, 46 | help='The number of points to evaluate in parallel during Bayesian' 47 | ' optimization. Defaults to 8.') 48 | @click.option( 49 | '--n-folds', type=int, default=4, 50 | help='The number of cross-validation folds to use. Defaults to 4.') 51 | @click.option( 52 | '--n-jobs', type=int, default=0, 53 | help='The number of parallel processes to use for tuning' 54 | ' hyper-parameters. At most n_folds * n_points processses can' 55 | ' be used at a given time. If 0, then the same number of' 56 | ' processes as CPUs will be used. Defaults to 0.') 57 | def run_shallow( 58 | data_dir: str, 59 | results_dir: str, 60 | splits: List[str], 61 | metric: str, 62 | n_iter: int, 63 | n_points: int, 64 | n_folds: int, 65 | n_jobs: int 66 | ) -> None: 67 | """Evaluate shallow baselines on the scruples resource. 68 | 69 | Train shallow baseline models on the scruples resource, reading 70 | the dataset from DATA_DIR, and writing trained models, logs, and 71 | other results to RESULTS_DIR. Performance is reported for each split 72 | provided as an argument. 73 | """ 74 | # Step 1: Manage and construct paths. 75 | 76 | logger.info('Creating the results directory.') 77 | 78 | os.makedirs(results_dir) 79 | model_paths = {} 80 | metrics_paths = collections.defaultdict(dict) 81 | predictions_paths = collections.defaultdict(dict) 82 | for baseline in baselines.resource.SHALLOW_BASELINES.keys(): 83 | os.makedirs(os.path.join(results_dir, baseline)) 84 | model_paths[baseline] = os.path.join( 85 | results_dir, baseline, 'model.pkl') 86 | for split in splits: 87 | os.makedirs(os.path.join(results_dir, baseline, split)) 88 | metrics_paths[baseline][split] = os.path.join( 89 | results_dir, baseline, split, 'metrics.json') 90 | predictions_paths[baseline][split] = os.path.join( 91 | results_dir, baseline, split, 'predictions.jsonl') 92 | 93 | # Step 2: Load the data. 
94 | 95 | logger.info(f'Loading the data from {data_dir}.') 96 | 97 | dataset = ScruplesResource(data_dir=data_dir) 98 | 99 | # Step 3: Run the baselines. 100 | 101 | logger.info('Running the baselines.') 102 | 103 | for baseline, (Model, hyper_parameter_space) in tqdm.tqdm( 104 | baselines.resource.SHALLOW_BASELINES.items(), 105 | **settings.TQDM_KWARGS 106 | ): 107 | # tune the hyper-parameters and train the model 108 | ids, features, labels, label_scores = dataset.train 109 | if hyper_parameter_space: 110 | model = BayesSearchCV( 111 | Model, 112 | hyper_parameter_space, 113 | scoring=make_scorer( 114 | score_func=METRICS[metric][1], 115 | **METRICS[metric][2]), 116 | n_iter=n_iter, 117 | n_points=n_points, 118 | cv=n_folds, 119 | n_jobs=os.cpu_count() if n_jobs == 0 else n_jobs, 120 | refit=True) 121 | else: 122 | model = Model 123 | model.fit(features, labels) 124 | 125 | # Step 4: Save the model. 126 | 127 | with open(model_paths[baseline], 'wb') as model_file: 128 | dill.dump(model, model_file) 129 | 130 | # Step 5: Run evaluation on the splits. 
131 | 132 | for split in splits: 133 | ids, features, labels, label_scores = getattr(dataset, split) 134 | 135 | predictions = model.predict(features) 136 | probabilities = model.predict_proba(features) 137 | 138 | with open(metrics_paths[baseline][split], 'w') as metrics_file: 139 | json.dump( 140 | { 141 | key: metric( 142 | y_true=labels, 143 | y_pred=probabilities 144 | if scorer_kwargs['needs_proba'] 145 | else predictions) 146 | for key, (_, metric, scorer_kwargs) in METRICS.items() 147 | }, 148 | metrics_file) 149 | 150 | with open(predictions_paths[baseline][split], 'w')\ 151 | as predictions_file: 152 | for id_, probs, prediction in zip( 153 | ids, probabilities, predictions 154 | ): 155 | predictions_file.write( 156 | json.dumps({ 157 | 'id': id_, 158 | 'label': prediction.tolist(), 159 | 'label_scores': probs.tolist() 160 | }) + '\n') 161 | -------------------------------------------------------------------------------- /src/scruples/scripts/make/__init__.py: -------------------------------------------------------------------------------- 1 | """Scripts for making the different components of scruples.""" 2 | 3 | import logging 4 | 5 | import click 6 | 7 | from . import ( 8 | resource, 9 | corpus) 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | # main function 16 | 17 | @click.group() 18 | def make(): 19 | """Make different components of scruples.""" 20 | pass 21 | 22 | 23 | subcommands = [ 24 | resource.resource, 25 | corpus.corpus 26 | ] 27 | 28 | for subcommand in subcommands: 29 | make.add_command(subcommand) 30 | -------------------------------------------------------------------------------- /src/scruples/scripts/make/corpus.py: -------------------------------------------------------------------------------- 1 | """Make the scruples corpus from raw reddit data. 2 | 3 | This script takes in posts and comments from the reddit API and creates 4 | the scruples dataset. 
5 | """ 6 | 7 | import collections 8 | import json 9 | import logging 10 | import os 11 | import random 12 | 13 | import attr 14 | import click 15 | import tqdm 16 | 17 | from ... import settings, utils 18 | from ...data.comment import Comment 19 | from ...data.post import Post 20 | from ...data.utils import instantiate_attrs_with_extra_kwargs 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | # main function 27 | 28 | @click.command() 29 | @click.argument( 30 | 'comments_path', 31 | type=click.Path(exists=True, file_okay=True, dir_okay=False)) 32 | @click.argument( 33 | 'posts_path', 34 | type=click.Path(exists=True, file_okay=True, dir_okay=False)) 35 | @click.argument( 36 | 'corpus_dir', 37 | type=click.Path(exists=False, file_okay=False, dir_okay=True)) 38 | def corpus( 39 | comments_path: str, 40 | posts_path: str, 41 | corpus_dir: str 42 | ) -> None: 43 | """Create the scruples corpus and write it to CORPUS_DIR. 44 | 45 | Read in the reddit posts from POSTS_PATH and comments from 46 | COMMENTS_PATH, create the scruples corpus, and write it to 47 | CORPUS_DIR. 48 | """ 49 | # Create the output directory. 50 | os.makedirs(corpus_dir) 51 | 52 | # Step 1: Read in the comments and index them by their link ids. 53 | logger.info('Reading the comments.') 54 | 55 | link_id_to_comments = collections.defaultdict(list) 56 | with click.open_file(comments_path, 'r') as comments_file: 57 | for ln in tqdm.tqdm(comments_file.readlines(), **settings.TQDM_KWARGS): 58 | comment = instantiate_attrs_with_extra_kwargs( 59 | Comment, 60 | **json.loads(ln)) 61 | 62 | # IDs are usually prefixed with something like "t1_", 63 | # "t2_", etc. to denote what type of object it is. Slice 64 | # off the first 3 characters to remove this prefix from 65 | # the link id because it will not be on the posts' IDs 66 | # when we join the comments to them. 
67 | link_id_to_comments[comment.link_id[3:]].append(comment) 68 | 69 | # Step 2: Read in the posts and join them with their comments. 70 | logger.info('Reading the posts.') 71 | 72 | posts = [] 73 | with click.open_file(posts_path, 'r') as posts_file: 74 | for ln in tqdm.tqdm(posts_file.readlines(), **settings.TQDM_KWARGS): 75 | kwargs = json.loads(ln) 76 | post = instantiate_attrs_with_extra_kwargs( 77 | Post, 78 | comments=link_id_to_comments[kwargs['id']], 79 | **kwargs) 80 | 81 | posts.append(post) 82 | 83 | # Step 3: Write the posts to disk. 84 | logger.info('Writing the posts to disk.') 85 | 86 | processed_posts_path = os.path.join(corpus_dir, settings.POSTS_FILENAME) 87 | with open(processed_posts_path, 'w') as processed_posts_file: 88 | for post in posts: 89 | processed_posts_file.write(json.dumps(attr.asdict(post)) + '\n') 90 | 91 | # Step 4: Filter out bad posts. 92 | logger.info('Filtering out bad posts.') 93 | 94 | dataset_posts = [ 95 | post 96 | for post in tqdm.tqdm(posts, **settings.TQDM_KWARGS) 97 | if post.is_good 98 | ] 99 | 100 | # Step 5: Create the splits then write them to disk. 101 | logger.info('Creating splits and writing them to disk.') 102 | 103 | # Shuffle dataset_posts so that the splits will be random. 104 | random.shuffle(dataset_posts) 105 | 106 | if [split['size'] for split in settings.SPLITS].count(None) > 1: 107 | raise ValueError( 108 | 'The settings.SPLITS constant should have at most ONE split' 109 | ' with a size of None.') 110 | 111 | # Make sure that the split with a size of ``None`` will be processed 112 | # last. 
113 | splits = [ 114 | split 115 | for split in settings.SPLITS 116 | if split['size'] is not None 117 | ] + [ 118 | split 119 | for split in settings.SPLITS 120 | if split['size'] is None 121 | ] 122 | for split in splits: 123 | split_path = os.path.join( 124 | corpus_dir, 125 | settings.CORPUS_FILENAME_TEMPLATE.format(split=split['name'])) 126 | with open(split_path, 'w') as split_file: 127 | if split['size'] is None: 128 | split_posts = dataset_posts 129 | dataset_posts = [] 130 | else: 131 | split_posts = dataset_posts[:split['size']] 132 | dataset_posts = dataset_posts[split['size']:] 133 | for post in tqdm.tqdm(split_posts, **settings.TQDM_KWARGS): 134 | instance = { 135 | 'id': utils.make_id(), 136 | 'post_id': post.id, 137 | 'action': 138 | attr.asdict(post.action) 139 | if post.action is not None else 140 | None, 141 | 'title': post.title, 142 | 'text': post.original_text, 143 | 'post_type': post.post_type.name, 144 | 'label_scores': { 145 | label.name: score 146 | for label, score 147 | in post.label_scores.label_to_score.items() 148 | }, 149 | 'label': post.label_scores.best_label.name, 150 | 'binarized_label_scores': { 151 | binarized_label.name: score 152 | for binarized_label, score 153 | in post.label_scores.binarized_label_to_score.items() 154 | }, 155 | 'binarized_label': post.label_scores\ 156 | .best_binarized_label.name 157 | } 158 | 159 | split_file.write(json.dumps(instance) + '\n') 160 | -------------------------------------------------------------------------------- /src/scruples/scripts/make/resource/__init__.py: -------------------------------------------------------------------------------- 1 | """Scripts for making the scruples resource.""" 2 | 3 | import logging 4 | 5 | import click 6 | 7 | from . 
@click.command()
@click.argument(
    'corpus_dir',
    type=click.Path(exists=True, file_okay=False, dir_okay=True))
@click.argument(
    'proposals_dir',
    type=click.Path(exists=False, file_okay=False, dir_okay=True))
@click.option(
    '--rounds', type=int, default=3,
    help='The number of rounds of random matchings to run. Each round'
         ' uses every action from the resource once. So, k rounds will'
         ' use an action at most k times. Duplicate action pairs are'
         ' dropped after each round. Defaults to 3.')
def proposals(
        corpus_dir: str,
        proposals_dir: str,
        rounds: int
) -> None:
    """Propose and write instances for the resource to PROPOSALS_DIR.

    Read the scruples corpus from CORPUS_DIR, then for each split of
    the corpus randomly pair actions together and write the pairs to
    PROPOSALS_DIR in a format ready for annotation on Mechanical Turk.
    """
    # Fail fast if the output directory already exists.
    os.makedirs(proposals_dir)

    # Build the resource proposals, one corpus split at a time.
    for split in settings.SPLITS:
        logger.info(f'Reading the {split["name"]} split from the corpus.')

        # Collect an (instance id, action) pair for every usable action.
        id_action_pairs = []
        corpus_split_path = os.path.join(
            corpus_dir,
            settings.CORPUS_FILENAME_TEMPLATE.format(split=split['name']))
        with open(corpus_split_path, 'r') as corpus_split_file:
            for line in tqdm.tqdm(corpus_split_file, **settings.TQDM_KWARGS):
                row = json.loads(line)
                if row['action'] is None:
                    continue
                action = Action(**row['action'])
                if action.is_good:
                    id_action_pairs.append((row['id'], action))

        logger.info(f'Computing random matchings for {rounds} rounds.')

        # Each round is one random perfect matching over the actions;
        # duplicate pairs (in either order) are dropped.
        observed_pairs = set()
        instances = []
        for _ in range(rounds):
            random.shuffle(id_action_pairs)
            for idx in range(0, len(id_action_pairs) - 1, 2):
                # Canonicalize the pair's order so that duplicates
                # collide regardless of which action came first.
                pair_key = tuple(sorted([
                    id_action_pairs[idx][0],
                    id_action_pairs[idx+1][0]
                ]))
                if pair_key in observed_pairs:
                    continue
                observed_pairs.add(pair_key)

                instances.append({
                    'id': utils.make_id(),
                    'actions': [
                        {
                            'id': id_,
                            'description': action.description
                        }
                        for id_, action in id_action_pairs[idx:idx+2]
                    ]
                })

        logger.info(f'Writing proposals to {proposals_dir}.')

        # Batch the instances into HIT-sized groups, one group per line.
        proposals_split_path = os.path.join(
            proposals_dir,
            settings.PROPOSALS_FILENAME_TEMPLATE.format(split=split['name']))
        with open(proposals_split_path, 'w') as proposals_split_file:
            for start in range(0, len(instances), settings.N_INSTANCES_PER_HIT):
                proposals_split_file.write(json.dumps({
                    'instances':
                        instances[start:start+settings.N_INSTANCES_PER_HIT]
                }) + '\n')
44 | 45 | # corpus settings 46 | 47 | CORPUS_FILENAME_TEMPLATE = '{split}.scruples-corpus.jsonl' 48 | """A template string for the corpus's split filenames.""" 49 | 50 | POSTS_FILENAME = 'all.scruples-posts.jsonl' 51 | """The filename for the file containing all the posts.""" 52 | 53 | 54 | # resource settings 55 | 56 | PROPOSALS_FILENAME_TEMPLATE = '{split}.scruples-proposals.jsonl' 57 | """A template string for the resource proposals' split filenames.""" 58 | 59 | RESOURCE_FILENAME_TEMPLATE = '{split}.scruples-actions.jsonl' 60 | """A template string for the resource's split filenames.""" 61 | 62 | N_ANNOTATORS_FOR_GOLD_LABELS = 5 63 | """The number of annotators to use for creating the gold labels.""" 64 | 65 | MIN_AGREEMENT = 5 66 | """The minimum number of gold annotators required to agree. 67 | 68 | The minimum number of gold annotators required to agree for the instance 69 | to be considered non-controversial. 70 | """ 71 | 72 | N_ANNOTATORS_FOR_HUMAN_PERFORMANCE = 5 73 | """The number of annotators to use for evaluating human performance.""" 74 | 75 | N_INSTANCES_PER_HIT = 20 76 | """The number of instances annotated in a single HIT.""" 77 | 78 | 79 | # evaluation 80 | 81 | LOSS_TYPES = [ 82 | 'xentropy-hard', 83 | 'xentropy-soft', 84 | 'xentropy-full', 85 | 'dirichlet-multinomial' 86 | ] 87 | """The different loss types for deep baseline models. 88 | 89 | ``"xentropy-hard"`` uses cross-entropy on the hard labels derived from 90 | the plurality. 91 | 92 | ``"xentropy-soft"`` uses cross-entropy against soft labels derived from 93 | averaging the individual labels together. 94 | 95 | ``"xentropy-full"`` uses the full negative log-likelihood objective with 96 | all of the annotations. So, unlike ``"xentropy-soft"``, it doesn't 97 | average the annotations then compute the cross-entropy, but simply sums 98 | the contributions from each label. 
This loss is equivalent to 99 | ``"xentropy-soft"`` in the case where each instance has the same number 100 | of annotations. 101 | 102 | ``"dirichlet-multinomial"`` uses a dirichlet-multinomial likelihood 103 | where the model is predicting the parameters for the dirichlet 104 | distribution as part of the hierarchical model. 105 | """ 106 | 107 | 108 | # demos 109 | 110 | def _coerce_if_not_none(value, type_): 111 | if value is None: 112 | return None 113 | 114 | return type_(value) 115 | 116 | 117 | NORMS_ACTIONS_BASELINE = _coerce_if_not_none( 118 | os.environ.get('SCRUPLES_NORMS_ACTIONS_BASELINE'), 119 | str) 120 | """The baseline to use for predicting the actions in the norms demo. 121 | 122 | This constant should be one of the keys from 123 | ``scruples.baselines.resource.FINE_TUNE_LM_BASELINES`` and should correspond to 124 | the ``NORMS_ACTIONS_MODEL`` setting. 125 | """ 126 | 127 | NORMS_ACTIONS_MODEL = _coerce_if_not_none( 128 | os.environ.get('SCRUPLES_NORMS_ACTIONS_MODEL'), 129 | str) 130 | """The path to the model directory to use for predicting the actions. 131 | 132 | The path to the fine-tuned Dirichlet-multinomial likelihood model to use when 133 | predicting the actions in the ``norms`` demo. The chosen directory should be 134 | the result of calling ``.save_pretrained`` on the model instance. See the 135 | transformers_ library for more details. 136 | 137 | .. _transformers: https://github.com/huggingface/transformers 138 | """ 139 | 140 | NORMS_CORPUS_BASELINE = _coerce_if_not_none( 141 | os.environ.get('SCRUPLES_NORMS_CORPUS_BASELINE'), 142 | str) 143 | """The baseline to use for predicting the corpus in the norms demo. 144 | 145 | This constant should be one of the keys from 146 | ``scruples.baselines.corpus.FINE_TUNE_LM_BASELINES`` and should correspond to 147 | the ``NORMS_CORPUS_MODEL`` setting. 
148 | """ 149 | 150 | NORMS_CORPUS_MODEL = _coerce_if_not_none( 151 | os.environ.get('SCRUPLES_NORMS_CORPUS_MODEL'), 152 | str) 153 | """The path to the model directory to use for predicting the corpus. 154 | 155 | The path to the fine-tuned Dirichlet-multinomial likelihood model to use when 156 | predicting the corpus in the ``norms`` demo. The chosen directory should be the 157 | result of calling ``.save_pretrained`` on the model instance. See the 158 | transformers_ library for more details. 159 | """ 160 | 161 | NORMS_PREDICT_BATCH_SIZE = _coerce_if_not_none( 162 | os.environ.get('SCRUPLES_NORMS_PREDICT_BATCH_SIZE'), 163 | int) 164 | """The batch size to use for predictions in the ``norms`` demo.""" 165 | 166 | NORMS_GPU_IDS = _coerce_if_not_none( 167 | os.environ.get('SCRUPLES_NORMS_GPU_IDS'), 168 | str) 169 | """The GPU IDs to use for making predictions. 170 | 171 | The GPU IDs to use when making predictions in the ``norms`` demo. In the 172 | environment variable specifying this configuration, the GPU IDs should be 173 | separated by commas (i.e., ``"0,1,2"``). 174 | """ 175 | 176 | 177 | # output and logging 178 | 179 | LOG_FORMAT = '%(asctime)s:%(levelname)s:%(name)s: %(message)s' 180 | """The format string for logging.""" 181 | 182 | TQDM_KWARGS = { 183 | 'ncols': 72, 184 | 'leave': False 185 | } 186 | """Key-word arguments for tqdm progress bars.""" 187 | -------------------------------------------------------------------------------- /src/scruples/vendor/__init__.py: -------------------------------------------------------------------------------- 1 | """Vendored code from various projects.""" 2 | 3 | from . import skopt 4 | -------------------------------------------------------------------------------- /src/scruples/vendor/skopt.py: -------------------------------------------------------------------------------- 1 | # New BSD License 2 | # 3 | # Copyright (c) 2016 - scikit-optimize developers. 4 | # 5 | # All rights reserved. 
6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are 9 | # met: 10 | # 11 | # a. Redistributions of source code must retain the above copyright 12 | # notice, this list of conditions and the following disclaimer. 13 | # 14 | # b. Redistributions in binary form must reproduce the above copyright 15 | # notice, this list of conditions and the following disclaimer in the 16 | # documentation and/or other materials provided with the distribution. 17 | # 18 | # c. Neither the name of the scikit-optimize developers nor the names of 19 | # its contributors may be used to endorse or promote products derived from 20 | # this software without specific prior written permission. 21 | # 22 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS 23 | # IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED 24 | # TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 25 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR 26 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 27 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 28 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 29 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 30 | # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 31 | # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 32 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 | 34 | """Code vendored from skopt (scikit-optimize). 35 | 36 | See https://github.com/scikit-optimize/scikit-optimize for the source. 37 | """ 38 | 39 | from skopt.utils import dump 40 | 41 | 42 | # vendored from commit: 6740876a6f9ad92c732d394e8534a5236a8d3f84 43 | class CheckpointSaver(object): 44 | """ 45 | Save current state after each iteration with `skopt.dump`. 
def pytest_addoption(parser):
    """Register the ``--skip-slow`` command-line flag with pytest."""
    parser.addoption(
        "--skip-slow",
        action="store_true",
        help="Skip slow tests, useful for quick checks.")


def pytest_collection_modifyitems(config, items):
    """Attach a skip marker to ``slow``-marked tests under ``--skip-slow``."""
    if not config.getoption("--skip-slow"):
        return

    slow_skip_marker = pytest.mark.skip(
        reason="--skip-slow option is turned on.")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(slow_skip_marker)
"Label this post nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"} 2 | {"id": "id_21", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Label this post yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "AUTHOR"} 3 | {"id": "id_22", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Label this post esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"} 4 | {"id": "id_23", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Label this post nah.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"} 5 | {"id": "id_24", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Label this post info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"} 6 | -------------------------------------------------------------------------------- /tests/fixtures/corpus-easy/test.scruples-corpus.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "id_25", "post_type": "HISTORICAL", "title": "AITA test post", "text": "The label for this post should be nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"} 2 | {"id": "id_26", "post_type": "HISTORICAL", "title": "AITA test post", "text": "The label for this post should be yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "AUTHOR"} 3 | {"id": "id_27", "post_type": "HISTORICAL", "title": "AITA test post", "text": "The label for this post should be esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"} 4 | {"id": "id_28", "post_type": "HISTORICAL", "title": "AITA test post", "text": "The label for this post should be nah.", 
"label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"} 5 | {"id": "id_29", "post_type": "HISTORICAL", "title": "AITA test post", "text": "The label for this post should be info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"} 6 | -------------------------------------------------------------------------------- /tests/fixtures/corpus-easy/train.scruples-corpus.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "id_0", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post is nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"} 2 | {"id": "id_1", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post is yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "AUTHOR"} 3 | {"id": "id_2", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post is esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"} 4 | {"id": "id_3", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post is nah.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"} 5 | {"id": "id_4", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post is info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"} 6 | {"id": "id_5", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Another post that is nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"} 7 | {"id": "id_6", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Another post that is yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": 
"AUTHOR"} 8 | {"id": "id_7", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Another post that is esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"} 9 | {"id": "id_8", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Another post that is nah.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"} 10 | {"id": "id_9", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Another post that is info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"} 11 | {"id": "id_10", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post's label is nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"} 12 | {"id": "id_11", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post's label is yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "AUTHOR"} 13 | {"id": "id_12", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post's label is esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"} 14 | {"id": "id_13", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post's label is nah.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"} 15 | {"id": "id_14", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post's label is info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"} 16 | {"id": "id_15", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Predict nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"} 17 | {"id": "id_16", "post_type": "HISTORICAL", "title": "AITA test 
post", "text": "Predict yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "AUTHOR"} 18 | {"id": "id_17", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Predict esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"} 19 | {"id": "id_18", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Predict nah.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"} 20 | {"id": "id_19", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Predict info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"} 21 | -------------------------------------------------------------------------------- /tests/fixtures/raw-data/test-comment.jsonl: -------------------------------------------------------------------------------- 1 | {"id":"aaaaaaa","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1540000000,"author":"example-user","body":"This comment is for testing.","score":10,"controversiality":0,"gilded":0} 2 | -------------------------------------------------------------------------------- /tests/fixtures/raw-data/test-post-comments.jsonl: -------------------------------------------------------------------------------- 1 | {"id":"aaaaaaa","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1541000000,"author":"example-user","body":"NTA, definitely nta.","score":6,"controversiality":0,"gilded":0} 2 | {"id":"aaaaaab","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1542000000,"author":"example-user","body":"I think YTA.","score":7,"controversiality":0,"gilded":0} 3 | 
{"id":"aaaaaac","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1543000000,"author":"example-user","body":"I'm gonna have to go with NTA on this one.","score":8,"controversiality":0,"gilded":0} 4 | {"id":"aaaaaad","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1544000000,"author":"example-user","body":"Very surprising!","score":9,"controversiality":0,"gilded":0} 5 | {"id":"aaaaaae","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1545000000,"author":"AutoModerator","body":"^^^^AUTOMOD ***This is a copy of the above post. It is a record of the post as originally written, in case the post is deleted or edited.***\n\nThe original text.\n\n*I am a bot, and this action was performed automatically. Please [contact the moderators of this subreddit](/message/compose/?to=/r/AmItheAsshole) if you have any questions or concerns.*","score":10,"controversiality":0,"gilded":0} 6 | -------------------------------------------------------------------------------- /tests/fixtures/raw-data/test-post.jsonl: -------------------------------------------------------------------------------- 1 | {"id":"aaaaaa","subreddit_id":"t5_aaaaa","subreddit":"AmItheAsshole","permalink":"/r/AmItheAsshole/comments/aaaaaa/aita_test_post/","domain":"self.AmItheAsshole","url":"https://www.reddit.com/r/AmItheAsshole/comments/aaaaaa/aita_test_post/","created_utc":1540000000,"author":"example-user","title":"AITA test post","selftext":"This post is for testing.","thumbnail":"default","score":10,"num_comments":10,"gilded":0,"retrieved_on":1545000000,"archived":false,"is_self":true,"over_18":false,"stickied":false} 2 | -------------------------------------------------------------------------------- /tests/fixtures/resource-easy/dev.scruples-actions.jsonl: 
-------------------------------------------------------------------------------- 1 | {"id": "id_16", "actions": [{"id": "action_id_16_0", "description": "The good action."}, {"id": "action_id_16_1", "description": "The bad action."}], "gold_label": 0, "gold_annotations": [1, 0]} 2 | {"id": "id_17", "actions": [{"id": "action_id_17_0", "description": "This is bad!"}, {"id": "action_id_17_1", "description": "This is good!"}], "gold_label": 1, "gold_annotations": [0, 1]} 3 | {"id": "id_18", "actions": [{"id": "action_id_18_0", "description": "This is good!"}, {"id": "action_id_18_1", "description": "The bad action."}], "gold_label": 0, "gold_annotations": [1, 0]} 4 | {"id": "id_19", "actions": [{"id": "action_id_19_0", "description": "This is bad!"}, {"id": "action_id_19_1", "description": "The good action."}], "gold_label": 1, "gold_annotations": [0, 1]} 5 | -------------------------------------------------------------------------------- /tests/fixtures/resource-easy/test.scruples-actions.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "id_20", "actions": [{"id": "action_id_20_0", "description": "Indeed, this describes a very good action."}, {"id": "action_id_20_1", "description": "Indeed, this describes a very bad action."}], "gold_label": 0, "gold_annotations": [1, 0]} 2 | {"id": "id_21", "actions": [{"id": "action_id_21_0", "description": "Always try to do bad."}, {"id": "action_id_21_1", "description": "Always try to do good."}], "gold_label": 1, "gold_annotations": [0, 1]} 3 | {"id": "id_22", "actions": [{"id": "action_id_22_0", "description": "Always try to do good."}, {"id": "action_id_22_1", "description": "Indeed, this describes a very bad action."}], "gold_label": 0, "gold_annotations": [1, 0]} 4 | {"id": "id_23", "actions": [{"id": "action_id_23_0", "description": "Always try to do bad."}, {"id": "action_id_23_1", "description": "Indeed, this describes a very good action."}], "gold_label": 1,
"gold_annotations": [0, 1]} 5 | -------------------------------------------------------------------------------- /tests/fixtures/resource-easy/train.scruples-actions.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "id_0", "actions": [{"id": "action_id_0_0", "description": "A good action."}, {"id": "action_id_0_1", "description": "A bad action."}], "gold_label": 0, "gold_annotations": [1, 0]} 2 | {"id": "id_1", "actions": [{"id": "action_id_1_0", "description": "A very bad action."}, {"id": "action_id_1_1", "description": "A very good action."}], "gold_label": 1, "gold_annotations": [0, 1]} 3 | {"id": "id_2", "actions": [{"id": "action_id_2_0", "description": "This action is good."}, {"id": "action_id_2_1", "description": "This action is bad."}], "gold_label": 0, "gold_annotations": [1, 0]} 4 | {"id": "id_3", "actions": [{"id": "action_id_3_0", "description": "bad!"}, {"id": "action_id_3_1", "description": "good!"}], "gold_label": 1, "gold_annotations": [0, 1]} 5 | {"id": "id_4", "actions": [{"id": "action_id_4_0", "description": "Very good."}, {"id": "action_id_4_1", "description": "Very bad."}], "gold_label": 0, "gold_annotations": [1, 0]} 6 | {"id": "id_5", "actions": [{"id": "action_id_5_0", "description": "A most bad thing to do."}, {"id": "action_id_5_1", "description": "A most good thing to do."}], "gold_label": 1, "gold_annotations": [0, 1]} 7 | {"id": "id_6", "actions": [{"id": "action_id_6_0", "description": "Do choose this good thing."}, {"id": "action_id_6_1", "description": "Don't choose this bad thing"}], "gold_label": 0, "gold_annotations": [1, 0]} 8 | {"id": "id_7", "actions": [{"id": "action_id_7_0", "description": "This description is about a bad action."}, {"id": "action_id_7_1", "description": "This description is about a good action."}], "gold_label": 1, "gold_annotations": [0, 1]} 9 | {"id": "id_8", "actions": [{"id": "action_id_8_0", "description": "A good action."}, {"id": "action_id_8_1", 
"description": "A very bad action."}], "gold_label": 0, "gold_annotations": [1, 0]} 10 | {"id": "id_9", "actions": [{"id": "action_id_9_0", "description": "A bad action."}, {"id": "action_id_9_1", "description": "A very good action."}], "gold_label": 1, "gold_annotations": [0, 1]} 11 | {"id": "id_10", "actions": [{"id": "action_id_10_0", "description": "good!"}, {"id": "action_id_10_1", "description": "This action is bad."}], "gold_label": 0, "gold_annotations": [1, 0]} 12 | {"id": "id_11", "actions": [{"id": "action_id_11_0", "description": "bad!"}, {"id": "action_id_11_1", "description": "This action is good."}], "gold_label": 1, "gold_annotations": [0, 1]} 13 | {"id": "id_12", "actions": [{"id": "action_id_12_0", "description": "A most good thing to do."}, {"id": "action_id_12_1", "description": "Very bad."}], "gold_label": 0, "gold_annotations": [1, 0]} 14 | {"id": "id_13", "actions": [{"id": "action_id_13_0", "description": "A most bad thing to do."}, {"id": "action_id_13_1", "description": "Very good."}], "gold_label": 1, "gold_annotations": [0, 1]} 15 | {"id": "id_14", "actions": [{"id": "action_id_14_0", "description": "This description is about a good action."}, {"id": "action_id_14_1", "description": "Don't choose this bad thing."}], "gold_label": 0, "gold_annotations": [1, 0]} 16 | {"id": "id_15", "actions": [{"id": "action_id_15_0", "description": "This description is about a bad action."}, {"id": "action_id_15_1", "description": "Do choose this good thing."}], "gold_label": 1, "gold_annotations": [0, 1]} 17 | -------------------------------------------------------------------------------- /tests/scruples/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the ``scruples`` package.""" 2 | -------------------------------------------------------------------------------- /tests/scruples/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | 
"""Tests for ``scruples.analysis``.""" 2 | -------------------------------------------------------------------------------- /tests/scruples/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for ``scruples.baselines``.""" 2 | -------------------------------------------------------------------------------- /tests/scruples/baselines/test_labels.py: -------------------------------------------------------------------------------- 1 | """Tests for scruples.baselines.labels.""" 2 | 3 | import unittest 4 | 5 | import numpy as np 6 | 7 | from scruples.baselines import labels 8 | 9 | 10 | class PriorBaselineTestCase(unittest.TestCase): 11 | """Test scruples.baselines.labels.PriorBaseline.""" 12 | 13 | FEATURES = [ 14 | [0, 1], 15 | [0, 1], 16 | [1, 0] 17 | ] 18 | LABELS = ['a', 'a', 'b'] 19 | 20 | def test_predicts_most_frequent_label(self): 21 | baseline = labels.PriorBaseline 22 | baseline.fit(X=self.FEATURES, y=self.LABELS) 23 | predictions = baseline.predict(self.FEATURES) 24 | 25 | self.assertEqual(predictions.tolist(), ['a', 'a', 'a']) 26 | 27 | def test_predicts_probabilities_by_class_prior(self): 28 | baseline = labels.PriorBaseline 29 | baseline.fit(X=self.FEATURES, y=self.LABELS) 30 | probabilities = baseline.predict_proba(self.FEATURES) 31 | 32 | self.assertAlmostEqual( 33 | probabilities.tolist(), 34 | [ 35 | [2./3., 1./3.], 36 | [2./3., 1./3.], 37 | [2./3., 1./3.] 
38 | ]) 39 | 40 | 41 | class StratifiedBaselineTestCase(unittest.TestCase): 42 | """Test scruples.baselines.labels.StratifiedBaseline.""" 43 | 44 | FEATURES = [ 45 | [0, 1], 46 | [0, 1], 47 | [1, 0] 48 | ] 49 | LABELS = ['a', 'a', 'b'] 50 | 51 | # the number of trials to perform in statistically testing that the 52 | # predictions adhere to the label frequencies from the training data 53 | N_TRIALS = 10000 54 | 55 | def test_predicts_random_label_by_class_probability(self): 56 | baseline = labels.StratifiedBaseline 57 | baseline.fit(X=self.FEATURES, y=self.LABELS) 58 | predictions = baseline.predict([[0, 0] for _ in range(self.N_TRIALS)]) 59 | 60 | elements, counts = np.unique(predictions, return_counts=True) 61 | 62 | # test that all labels are predicted 63 | self.assertEqual(elements.tolist(), ['a', 'b']) 64 | 65 | # check that the mean counts for each label are within 5 66 | # standard deviations of their expectations 67 | std = ((2./3. * 1./3.) / self.N_TRIALS)**0.5 68 | a_mean = counts[0] / self.N_TRIALS 69 | b_mean = counts[1] / self.N_TRIALS 70 | self.assertGreater(a_mean, 2./3. - 5 * std) 71 | self.assertLess(a_mean, 2./3. + 5 * std) 72 | self.assertGreater(b_mean, 1./3. - 5 * std) 73 | self.assertLess(b_mean, 1./3. 
+ 5 * std) 74 | -------------------------------------------------------------------------------- /tests/scruples/baselines/test_linear.py: -------------------------------------------------------------------------------- 1 | """Tests for scruples.baselines.linear.""" 2 | 3 | import unittest 4 | 5 | from scruples.baselines import linear 6 | from .utils import BaselineTestMixin 7 | 8 | 9 | class LogististicRegressionBaselineTestCase( 10 | BaselineTestMixin, 11 | unittest.TestCase 12 | ): 13 | """Test the logistic regression on bag-of-ngrams baseline.""" 14 | 15 | BASELINE_MODEL = linear.LogisticRegressionBaseline 16 | BASELINE_HYPER_PARAMETERS = linear.LOGISTIC_REGRESSION_HYPER_PARAMETERS 17 | DATASET = 'corpus' 18 | 19 | 20 | class LogististicRankerBaselineTestCase( 21 | BaselineTestMixin, 22 | unittest.TestCase 23 | ): 24 | """Test the logistic ranker on bag-of-ngrams baseline.""" 25 | 26 | BASELINE_MODEL = linear.LogisticRankerBaseline 27 | BASELINE_HYPER_PARAMETERS = linear.LOGISTIC_RANKER_HYPER_PARAMETERS 28 | DATASET = 'resource' 29 | -------------------------------------------------------------------------------- /tests/scruples/baselines/test_loss.py: -------------------------------------------------------------------------------- 1 | """Tests for scruples.baselines.loss.""" 2 | 3 | import unittest 4 | 5 | import numpy as np 6 | import torch as th 7 | 8 | from scruples.baselines import loss 9 | 10 | 11 | class TestSoftCrossEntropyLoss(unittest.TestCase): 12 | """Test SoftCrossEntropyLoss.""" 13 | 14 | def test_forward(self): 15 | f = loss.SoftCrossEntropyLoss() 16 | 17 | self.assertAlmostEqual( 18 | f( 19 | th.Tensor([[0., 0.]]), 20 | th.Tensor([[0.5, 0.5]]) 21 | ).item(), 22 | 0.6931471805599453) 23 | self.assertAlmostEqual( 24 | f( 25 | th.Tensor([[0., 0.], [-1.0986123, 1.0986123]]), 26 | th.Tensor([[0.3, 0.7], [0.6, 0.4]]) 27 | ).item(), 28 | 1.0584212213097515) 29 | 30 | 31 | class TestDirichletMultinomialLoss(unittest.TestCase): 32 | """Test 
class TestDirichletMultinomialLoss(unittest.TestCase):
    """Test DirichletMultinomialLoss."""

    def test_forward(self):
        criterion = loss.DirichletMultinomialLoss()

        # when the concentrations [a, b] go to [0, 0], the dirichlet prior
        # becomes a Bernoulli distribution with p = 0.5, where there's a
        # 50/50 chance that the label always comes up 0 or always comes up 1
        self.assertAlmostEqual(
            criterion(
                th.Tensor([[np.log(1e-20), np.log(1e-20)]]),
                th.Tensor([[3, 0]])
            ).item(),
            0.6931471805599453,
            places=4)

        # other (log-concentrations, observed counts, expected loss) triples
        other_cases = [
            ([[np.log(1.), np.log(1.)]],
             [[0.5, 0.5]],
             0.9347116558304358),
            ([[np.log(1.), np.log(1.)],
              [np.log(0.5), np.log(2.)]],
             [[0.3, 0.7], [0.6, 0.4]],
             1.1093992405423625),
        ]
        for log_alphas, observed, expected in other_cases:
            self.assertAlmostEqual(
                criterion(
                    th.Tensor(log_alphas),
                    th.Tensor(observed)
                ).item(),
                expected)


# ---- tests/scruples/baselines/test_metrics.py ----

"""Tests for scruples.baselines.metrics."""

# ---- tests/scruples/baselines/test_naivebayes.py ----

"""Tests for scruples.baselines.naivebayes."""

import unittest

from scruples.baselines import naivebayes
from .utils import BaselineTestMixin


class BernoulliNBBaselineTestCase(
        BaselineTestMixin,
        unittest.TestCase
):
    """Test the bernoulli naive bayes on bag-of-ngrams baseline."""

    BASELINE_MODEL = naivebayes.BernoulliNBBaseline
    BASELINE_HYPER_PARAMETERS = naivebayes.BERNOULLINB_HYPER_PARAMETERS
    DATASET = 'corpus'


class MultinomialNBBaselineTestCase(
        BaselineTestMixin,
        unittest.TestCase
):
    """Test the multinomial naive bayes on bag-of-ngrams baseline."""

    BASELINE_MODEL = naivebayes.MultinomialNBBaseline
    BASELINE_HYPER_PARAMETERS = naivebayes.MULTINOMIALNB_HYPER_PARAMETERS
    DATASET = 'corpus'


class ComplementNBBaselineTestCase(
        BaselineTestMixin,
        unittest.TestCase
):
    """Test the complement naive bayes on bag-of-ngrams baseline."""

    BASELINE_MODEL = naivebayes.ComplementNBBaseline
    BASELINE_HYPER_PARAMETERS = naivebayes.COMPLEMENTNB_HYPER_PARAMETERS
    DATASET = 'corpus'


# ---- tests/scruples/baselines/test_utils.py ----

"""Tests for scruples.baselines.utils."""

import unittest
from unittest.mock import Mock

import numpy as np
import pandas as pd

from scruples.baselines import utils


class ConcatTitleAndTextTestCase(unittest.TestCase):
    """Test concat_title_and_text."""

    def test_concatenates_title_and_text_columns(self):
        features = pd.DataFrame({
            'title': ['The Title A', 'The Title B'],
            'text': ['The text A.', 'The text B.']
        })

        self.assertEqual(
            utils.concat_title_and_text(features).tolist(),
            ['The Title A\nThe text A.', 'The Title B\nThe text B.'])


class DirichletMultinomialTestCase(unittest.TestCase):
    """Test dirichlet_multinomial."""

    def test_dirichlet_multinomial(self):
        # (log concentration parameters, expected posterior means) pairs
        cases = [
            ([[np.log(0.1), np.log(0.1)]],
             [[0.5, 0.5]]),
            ([[np.log(1.), np.log(1.)],
              [np.log(1.), np.log(4.)]],
             [[0.5, 0.5], [0.2, 0.8]]),
            ([[np.log(1.), np.log(1.), np.log(1.)]],
             [[1/3., 1/3., 1/3.]]),
            ([[np.log(1.), np.log(2.), np.log(1.)],
              [np.log(2.), np.log(3.), np.log(5.)]],
             [[0.25, 0.5, 0.25], [0.2, 0.3, 0.5]]),
        ]
        for log_alphas, expected in cases:
            self.assertTrue(np.allclose(
                utils.dirichlet_multinomial(log_alphas).tolist(),
                expected))
class ResourceTransformerTestCase(unittest.TestCase):
    """Test ResourceTransformer."""

    def test_set_params(self):
        # mock out the inputs
        transformer = Mock()
        transformer.get_params.return_value = {'x': None}
        second_transformer = Mock()
        second_transformer.get_params.return_value = {'x': None}
        third_transformer = Mock()
        third_transformer.get_params.return_value = {'x': None}

        resource_transformer = utils.ResourceTransformer(
            transformer=transformer)

        # test setting params on the transformer attribute
        resource_transformer.set_params(transformer__x=1)

        transformer.set_params.assert_called()
        transformer.set_params.assert_called_with(x=1)

        # test setting params on the ResourceTransformer
        self.assertNotEqual(
            resource_transformer.transformer,
            second_transformer)

        resource_transformer.set_params(
            transformer=second_transformer)

        self.assertEqual(
            resource_transformer.transformer,
            second_transformer)

        # test setting params on the ResourceTransformer and the new
        # transformer at the same time
        self.assertNotEqual(
            resource_transformer.transformer,
            third_transformer)

        resource_transformer.set_params(
            transformer=third_transformer,
            transformer__x='foo')

        self.assertEqual(
            resource_transformer.transformer,
            third_transformer)

        # BUG FIX: this previously read ``asset_called`` (a typo for
        # ``assert_called``); Mock auto-creates attributes on access, so
        # the misspelled call silently did nothing instead of asserting.
        third_transformer.set_params.assert_called()
        third_transformer.set_params.assert_called_with(x='foo')

    def test_fit(self):
        # create the data
        X = pd.DataFrame([
            {'action0': 1, 'action1': 4},
            {'action0': 2, 'action1': 5},
            {'action0': 3, 'action1': 6},
        ])

        # mock arguments to instantiate ResourceTransformer
        transformer = Mock()

        # create the ResourceTransformer instance
        resource_transformer = utils.ResourceTransformer(
            transformer=transformer)

        # fit should return the (fitted) ResourceTransformer itself
        self.assertIsInstance(
            resource_transformer.fit(X),
            utils.ResourceTransformer)

        # the inner transformer should be fit on all the action columns,
        # stacked into a single series
        transformer.fit.assert_called()
        self.assertEqual(
            transformer.fit.call_args[0][0].tolist(),
            pd.concat([X['action0'], X['action1']]).tolist())

    def test_transform(self):
        # create the data
        X = pd.DataFrame([
            {'action0': 1, 'action1': 4},
            {'action0': 2, 'action1': 4},
            {'action0': 3, 'action1': 4}
        ])

        X_transformed = [6, 4, 2]

        class TimesTwoTransformer(object):
            def fit(self, X, y=None):
                return self

            def transform(self, X):
                return 2 * X

        # (the unused ``resources`` local from the original was removed)
        transformer = TimesTwoTransformer()

        resource_transformer = utils.ResourceTransformer(
            transformer=transformer)

        resource_transformer.fit(X)

        self.assertEqual(
            resource_transformer.transform(X).tolist(),
            X_transformed)


# ---- tests/scruples/baselines/utils.py ----

"""Utilities for testing baselines."""

import os
import tempfile

import pandas as pd
import pytest
from sklearn import metrics
from skopt import BayesSearchCV

from scruples import settings as scruples_settings
from scruples.dataset import readers
from ... import settings, utils
class BaselineTestMixin:
    """Test a baseline model by verifying it solves an easy dataset.

    Attributes
    ----------
    BASELINE_MODEL : sklearn.base.BaseEstimator
        The baseline model to test.
    BASELINE_HYPER_PARAMETERS : Dict
        The dictionary defining the hyper-parameter search space for the
        baseline model.
    DATASET : str
        The dataset against which the baseline should be run. Must be
        either ``"resource"`` or ``"corpus"``.

    Examples
    --------
    To create a test case for a baseline model, inherit from this class
    along with ``unittest.TestCase`` and provide the ``BASELINE_MODEL``,
    ``BASELINE_HYPER_PARAMETERS``, and ``DATASET`` class attributes::

        class LogisticRegressionBaselineTestCase(
                BaselineTestMixin,
                unittest.TestCase
        ):
            '''Test the logistic regression baseline.'''

            BASELINE_MODEL = LogisticRegressionBaseline
            BASELINE_HYPER_PARAMETERS = LOGISTIC_REGRESSION_HYPER_PARAMS
            DATASET = 'corpus'

    """

    BASELINE_MODEL = None
    BASELINE_HYPER_PARAMETERS = None
    DATASET = None

    def _validate_class(self):
        """Raise ``ValueError`` if a required class attribute is unset."""
        if self.BASELINE_MODEL is None:
            raise ValueError(
                'Subclasses of BaselineTestMixin must provide a'
                ' BASELINE_MODEL class attribute.')

        if self.BASELINE_HYPER_PARAMETERS is None:
            raise ValueError(
                'Subclasses of BaselineTestMixin must provide a'
                ' BASELINE_HYPER_PARAMETERS class attribute.')

        if self.DATASET is None:
            raise ValueError(
                'Subclasses of BaselineTestMixin must provide a DATASET'
                ' class attribute.')

        if self.DATASET not in ['resource', 'corpus']:
            raise ValueError(
                'The DATASET class attribute must either be'
                ' "resource", or "corpus".')

    def setUp(self):
        super().setUp()

        # validate the class
        self._validate_class()

        # copy the dataset fixture from the package to disk

        if self.DATASET == 'resource':
            Reader = readers.ScruplesResource
            fixture_path = settings.RESOURCE_EASY_DIR
            split_filename_template =\
                scruples_settings.RESOURCE_FILENAME_TEMPLATE
        elif self.DATASET == 'corpus':
            Reader = readers.ScruplesCorpus
            fixture_path = settings.CORPUS_EASY_DIR
            split_filename_template =\
                scruples_settings.CORPUS_FILENAME_TEMPLATE

        self.temp_dir = tempfile.TemporaryDirectory()
        # BUG FIX: register the cleanup immediately rather than relying on
        # tearDown -- tearDown is not run when setUp fails partway, so the
        # temporary directory used to leak if a fixture copy below raised.
        # addCleanup also keeps the unittest super() chain intact.
        self.addCleanup(self.temp_dir.cleanup)

        for split in Reader.SPLITS:
            split_filename = split_filename_template.format(split=split)
            utils.copy_pkg_resource_to_disk(
                pkg='tests',
                src=os.path.join(fixture_path, split_filename),
                dst=os.path.join(self.temp_dir.name, split_filename))

        # load the dataset

        self.dataset = Reader(data_dir=self.temp_dir.name)

    @pytest.mark.slow
    def test_it_solves_scruples_easy_when_untuned(self):
        baseline = self.BASELINE_MODEL

        # train the model
        _, train_features, train_labels, train_label_scores =\
            self.dataset.train
        baseline.fit(train_features, train_labels)

        # predict with the model on dev
        _, dev_features, dev_labels, dev_label_scores =\
            self.dataset.dev
        predictions = baseline.predict(dev_features)

        # check that the accuracy is 100%
        self.assertEqual(
            metrics.accuracy_score(
                y_true=dev_labels,
                y_pred=predictions),
            1.)

    @pytest.mark.slow
    def test_it_solves_scruples_easy_when_tuned(self):
        baseline = BayesSearchCV(
            self.BASELINE_MODEL,
            self.BASELINE_HYPER_PARAMETERS,
            n_iter=16,
            n_points=2,
            cv=4,
            n_jobs=1,
            refit=True)

        # train the model, tuning hyper-parameters
        _, train_features, train_labels, train_label_scores =\
            self.dataset.train
        baseline.fit(train_features, train_labels)

        # predict with the model on dev
        _, dev_features, dev_labels, dev_label_scores =\
            self.dataset.dev
        predictions = baseline.predict(dev_features)

        # check that the accuracy is 100%
        self.assertEqual(
            metrics.accuracy_score(
                y_true=dev_labels,
                y_pred=predictions),
            1.)


# ---- tests/scruples/data/__init__.py ----

"""Tests for ``scruples.data``."""

# ---- tests/scruples/data/test_utils.py ----

"""Tests for scruples.data.utils."""

import unittest

import attr

from scruples.data import utils
class InstantiateAttrsWithExtraKwargsTestCase(unittest.TestCase):
    """Test instantiate_attrs_with_extra_kwargs."""

    class NonAttrsClass:
        def __init__(self, foo):
            self.foo = foo

    @attr.s
    class AttrsClass:
        foo = attr.ib()

    def test_raises_error_on_non_attrs_classes(self):
        # classes not decorated with @attr.s cannot be introspected for
        # their attributes, so they must be rejected
        with self.assertRaisesRegex(
                ValueError,
                r'.* must be decorated with @attr\.s'
        ):
            utils.instantiate_attrs_with_extra_kwargs(
                self.NonAttrsClass, foo=1)

    def test_instantiates_class_with_no_extra_kwargs(self):
        instance = utils.instantiate_attrs_with_extra_kwargs(
            self.AttrsClass, foo=1)

        self.assertEqual(instance.foo, 1)

    def test_instantiates_class_with_extra_kwargs(self):
        # the extra keyword argument ("bar") should be discarded
        instance = utils.instantiate_attrs_with_extra_kwargs(
            self.AttrsClass, foo='a', bar='b')

        self.assertEqual(instance.foo, 'a')


class CachedPropertyTestCase(unittest.TestCase):
    """Test cached_property."""

    def test_makes_method_into_property(self):
        class Holder:
            @utils.cached_property
            def bar(self):
                return 1

        self.assertEqual(Holder().bar, 1)

    def test_caches_property_from_method(self):
        class Holder:
            @utils.cached_property
            def bar(self):
                # if not cached, each access increments the return value
                # by counting how many times the method body has run
                self.num_calls = 1 + getattr(self, 'num_calls', 0)

                return self.num_calls

        holder = Holder()

        # the first access computes 1 ...
        self.assertEqual(holder.bar, 1)
        # ... and a second access must hit the cache, so it also
        # returns 1 rather than 2
        self.assertEqual(holder.bar, 1)


# ---- tests/scruples/dataset/__init__.py ----

"""Tests for ``scruples.dataset``."""

# ---- tests/scruples/demos/__init__.py ----

"""Tests for ``scruples.demos``."""

# ---- tests/scruples/demos/norms/__init__.py ----

"""Tests for ``scruples.demos.norms``."""

# ---- tests/scruples/demos/norms/test_utils.py ----

"""Tests for scruples.demos.norms.utils."""

import unittest
from unittest import mock

from scruples.demos.norms import utils


class PredictionDatasetTestCase(unittest.TestCase):
    """Test scruples.demos.norms.utils.PredictionDataset."""

    def setUp(self):
        self.features = [[1, 2], [1, 0]]
        self.transform = lambda x: x

        self.dataset = utils.PredictionDataset(
            features=self.features,
            transform=self.transform)

    def test___init__(self):
        self.assertEqual(self.dataset.features, self.features)
        self.assertEqual(self.dataset.transform, self.transform)

    def test___len__(self):
        self.assertEqual(len(self.dataset), len(self.features))

    def test___get_item__(self):
        # with the identity transform, indexing hands features back
        # unchanged
        for idx, feature in enumerate(self.features):
            self.assertEqual(self.dataset[idx], feature)

    def test_it_applies_transform(self):
        stub_transform = mock.MagicMock(return_value='foo')

        dataset = utils.PredictionDataset(
            features=self.features,
            transform=stub_transform)

        self.assertEqual(dataset[1], 'foo')
        stub_transform.assert_called_with(self.features[1])


# ---- tests/scruples/demos/scoracle/__init__.py ----

"""Tests for ``scruples.demos.scoracle``."""

# ---- tests/scruples/extraction/__init__.py ----

"""Tests for ``scruples.extraction``."""
"""Tests for scruples.extraction.base."""

import unittest

from scruples.extraction import base


class LoggedCallableTestCase(unittest.TestCase):
    """Test scruples.extraction.base.LoggedCallable."""

    # classes for testing LoggedCallable

    class Add1(base.LoggedCallable):
        def apply(self, y):
            return y + 1

    class AddX(base.LoggedCallable):
        def __init__(self, x, *args, **kwargs):
            super().__init__(*args, **kwargs)

            self.x = x

        def apply(self, y):
            return self.x + y

    # helpers

    @staticmethod
    def _expected_log(pairs):
        """Return the call log expected for ``pairs`` of (arg, result).

        Each log entry has the shape ``((args, kwargs), output)`` for a
        single positional-argument call.
        """
        return [(((arg,), {}), result) for arg, result in pairs]

    # tests

    def test_apply_raises_not_implemented(self):
        logged_callable = base.LoggedCallable()

        with self.assertRaises(NotImplementedError):
            logged_callable.apply()

        with self.assertRaises(NotImplementedError):
            logged_callable()

    def test_subclass_without_args(self):
        add1 = self.Add1()

        self.assertEqual(add1(0), 1)
        self.assertEqual(add1(1), 2)
        self.assertEqual(add1(5), 6)

        # every call above should have been recorded, in order
        self.assertEqual(
            add1.call_log,
            self._expected_log([(0, 1), (1, 2), (5, 6)]))

    def test_subclass_with_args(self):
        add5 = self.AddX(x=5)

        self.assertEqual(add5(0), 5)
        self.assertEqual(add5(1), 6)
        self.assertEqual(add5(5), 10)

        # every call above should have been recorded, in order
        self.assertEqual(
            add5.call_log,
            self._expected_log([(0, 5), (1, 6), (5, 10)]))

    def test_when_log_calls_is_false(self):
        add1 = self.Add1(log_calls=False)

        self.assertEqual(add1(0), 1)
        self.assertEqual(add1(1), 2)
        self.assertEqual(add1(5), 6)

        # with logging disabled the call log must remain None
        self.assertEqual(add1.call_log, None)

        add5 = self.AddX(x=5, log_calls=False)

        self.assertEqual(add5(0), 5)
        self.assertEqual(add5(1), 6)
        self.assertEqual(add5(5), 10)

        # with logging disabled the call log must remain None
        self.assertEqual(add5.call_log, None)


class CaseTestCase(unittest.TestCase):
    """Test scruples.extraction.base.Case."""

    class SomeToNoneCase(base.Case):
        def match(self, x):
            return (x, x is not None)

        def transform(self, x):
            return None

        def filter(self, x):
            return False

    # tests

    def test_subclass_when_case_matches(self):
        # every non-None input matches and is transformed to None
        for matching_input in (1, 'a', ''):
            self.assertEqual(
                self.SomeToNoneCase()(matching_input),
                (None, True))

    def test_subclass_when_case_does_not_match(self):
        self.assertEqual(
            self.SomeToNoneCase()(None),
            (None, False))


# ---- tests/scruples/extraction/test_transformers.py ----

"""Tests for scruples.extraction.transformers."""

import unittest

from scruples.extraction import transformers
class GerundifyingTransformerTestCase(unittest.TestCase):
    """Test scruples.extraction.transformers.GerundifyingTransformer."""

    def _check_pairs(self, pairs):
        """Assert the transformer maps each input to its expected output.

        Parameters
        ----------
        pairs : List[Tuple[str, str]]
            ``(input, expected output)`` pairs to check.
        """
        transform = transformers.GerundifyingTransformer()
        for text, expected in pairs:
            # subTest reports each failing pair individually instead of
            # stopping at the first mismatch
            with self.subTest(text=text):
                self.assertEqual(transform(text), expected)

    def test_leaves_gerund_phrases_alone(self):
        self._check_pairs([
            ('Running down the street',
             'running down the street'),
            ('Running down the street and waving goodbye to my friend.',
             'running down the street and waving goodbye to my friend.'),
        ])

    def test_transforms_infinitives(self):
        self._check_pairs([
            # simple case
            ('to run through a stop light',
             'running through a stop light'),
            ('to wave to a friend.',
             'waving to a friend.'),
            # coordinated verbs
            ('to run through and totally ignore a stop light',
             'running through and totally ignoring a stop light'),
            ('to wave to a friend and say goodbye.',
             'waving to a friend and saying goodbye.'),
        ])

    def test_transforms_I_phrases(self):
        self._check_pairs([
            # simple case
            ('I ran through a stop light',
             'running through a stop light'),
            ('I waved to a friend.',
             'waving to a friend.'),
            # coordinated verbs
            ('I ran through and totally ignored a stop light',
             'running through and totally ignoring a stop light'),
            ('I waved to a friend and said goodbye.',
             'waving to a friend and saying goodbye.'),
        ])

    def test_drops_auxiliary_verbs(self):
        self._check_pairs([
            # will, simple case
            ('I will run through a stop light',
             'running through a stop light'),
            ("I won't run through a stop light",
             'not running through a stop light'),
            ('I will wave to a friend.',
             'waving to a friend.'),
            ("I won't wave to a friend.",
             'not waving to a friend.'),
            # will, coordinated verbs
            ('I will run through and totally ignore a stop light',
             'running through and totally ignoring a stop light'),
            ("I won't run through and totally ignore a stop light",
             'not running through and totally ignoring a stop light'),
            ('I will wave to a friend and say goodbye.',
             'waving to a friend and saying goodbye.'),
            ("I won't wave to a friend and say goodbye.",
             'not waving to a friend and saying goodbye.'),
            # do, simple case
            ('I did run through a stop light',
             'running through a stop light'),
            ("I didn't run through a stop light",
             'not running through a stop light'),
            ('I did wave to a friend.',
             'waving to a friend.'),
            ("I didn't wave to a friend.",
             'not waving to a friend.'),
            # do, coordinated verbs
            ('I did run through and totally ignore a stop light',
             'running through and totally ignoring a stop light'),
            ("I didn't run through and totally ignore a stop light",
             'not running through and totally ignoring a stop light'),
            ('I did wave to a friend and say goodbye.',
             'waving to a friend and saying goodbye.'),
            ("I didn't wave to a friend and say goodbye.",
             'not waving to a friend and saying goodbye.'),
        ])

    def test_handles_auxiliary_forms_of_to_be(self):
        self._check_pairs([
            # to be when it is _not_ auxiliary
            ("I'm mad at my friend",
             'being mad at my friend'),
            ("I'm happy all the time.",
             'being happy all the time.'),
            # to be when it is auxiliary
            ("I'm thinking about making some coffee",
             'thinking about making some coffee'),
            ("I'm running for president.",
             'running for president.'),
        ])

    def test_handles_Im(self):
        self._check_pairs([
            ('Im happy to see you',
             'being happy to see you'),
            ('Im sure it is.',
             'being sure it is.'),
        ])


# ---- tests/scruples/test_settings.py ----

"""Tests for scruples.settings."""

# ---- tests/settings.py ----

"""Settings for the tests package."""


# fixture locations

# easy version dataset fixtures
CORPUS_EASY_DIR = 'fixtures/corpus-easy'
RESOURCE_EASY_DIR = 'fixtures/resource-easy'

# raw data fixtures
TEST_COMMENT_PATH = 'fixtures/raw-data/test-comment.jsonl'
TEST_POST_PATH = 'fixtures/raw-data/test-post.jsonl'
TEST_POST_COMMENTS_PATH = 'fixtures/raw-data/test-post-comments.jsonl'

# gold extraction examples
GOLD_TITLE_DESCRIPTION_EXTRACTIONS_PATH =\
    'fixtures/gold-extractions/gold-title-description-extractions.jsonl'

# ---- tests/utils.py ----

"""Utilities for writing tests."""

import shutil

import pkg_resources


def copy_pkg_resource_to_disk(
        pkg: str,
        src: str,
        dst: str
) -> None:
    """Copy a package resource to disk.

    Copy the resource from the package ``pkg`` at the path ``src`` to
    disk at the path ``dst``.

    Parameters
    ----------
    pkg : str
        The package holding the resource.
    src : str
        The source path for the resource in ``pkg``.
    dst : str
        The destination path for the resource on disk.

    Notes
    -----
    This function is primarily useful for testing code that requires
    resources to be written on disk, when those test fixtures are
    shipped in the package.
    """
    with pkg_resources.resource_stream(pkg, src) as src_file,\
            open(dst, 'wb') as dst_file:
        # BUG FIX: stream the copy in fixed-size chunks instead of the
        # previous ``dst_file.write(src_file.read())``, which loaded the
        # entire resource into memory at once.
        shutil.copyfileobj(src_file, dst_file)