├── .dockerignore
├── .gitignore
├── .skiff
├── cloudbuild-deploy.yaml
├── norms.skiff.json
├── scoracle.skiff.json
└── webapp.jsonnet
├── Dockerfile
├── LICENSE
├── MANIFEST.in
├── bin
└── scruples
├── docs
├── annotation-guidelines.md
├── assets
│ └── demo-screenshots
│ │ ├── norms-about.png
│ │ ├── norms-results.png
│ │ ├── scoracle-about.png
│ │ └── scoracle-results.png
└── demos.md
├── hit-definitions
└── resource-labeling
│ ├── data.jsonl
│ └── definition
│ ├── NOTES
│ ├── hitproperties.json
│ ├── hittypeproperties.json
│ └── question.xml.j2
├── readme.md
├── requirements.txt
├── setup.cfg
├── setup.py
├── src
└── scruples
│ ├── __init__.py
│ ├── analysis
│ ├── __init__.py
│ └── traits.py
│ ├── baselines
│ ├── __init__.py
│ ├── bert.py
│ ├── corpus.py
│ ├── labels.py
│ ├── linear.py
│ ├── loss.py
│ ├── metrics.py
│ ├── naivebayes.py
│ ├── resource.py
│ ├── roberta.py
│ ├── style.py
│ ├── train.py
│ ├── trees.py
│ └── utils.py
│ ├── data
│ ├── __init__.py
│ ├── action.py
│ ├── comment.py
│ ├── label_scores.py
│ ├── labels.py
│ ├── post.py
│ ├── post_types.py
│ └── utils.py
│ ├── dataset
│ ├── __init__.py
│ ├── readers.py
│ └── transforms.py
│ ├── demos
│ ├── __init__.py
│ ├── norms
│ │ ├── __init__.py
│ │ ├── app.py
│ │ ├── static
│ │ │ ├── favicon.ico
│ │ │ └── logo-ai2-white-withText-micro.svg
│ │ ├── templates
│ │ │ └── index.html
│ │ └── utils.py
│ └── scoracle
│ │ ├── __init__.py
│ │ ├── app.py
│ │ ├── static
│ │ ├── favicon.ico
│ │ └── logo-ai2-white-withText-micro.svg
│ │ └── templates
│ │ ├── dev.scruples-anecdotes_label-counts.json
│ │ ├── dev.scruples-dilemmas_label-counts.json
│ │ ├── dev.snli_label-counts.json
│ │ └── index.html
│ ├── extraction
│ ├── __init__.py
│ ├── base.py
│ ├── filters.py
│ ├── normalizers.py
│ └── transformers.py
│ ├── scripts
│ ├── __init__.py
│ ├── analyze
│ │ ├── __init__.py
│ │ ├── corpus
│ │ │ ├── __init__.py
│ │ │ ├── extractions.py
│ │ │ ├── human_performance.py
│ │ │ ├── oracle_performance.py
│ │ │ ├── predictions.py
│ │ │ └── statistics.py
│ │ ├── oracle_estimator.py
│ │ └── resource
│ │ │ ├── __init__.py
│ │ │ ├── human_performance.py
│ │ │ ├── latent_traits.py
│ │ │ ├── oracle_performance.py
│ │ │ ├── predictions.py
│ │ │ ├── topics.py
│ │ │ └── verbs.py
│ ├── demo
│ │ ├── __init__.py
│ │ ├── norms.py
│ │ └── scoracle.py
│ ├── evaluate
│ │ ├── __init__.py
│ │ ├── corpus
│ │ │ ├── __init__.py
│ │ │ ├── predict_lm.py
│ │ │ ├── run_shallow.py
│ │ │ └── tune_lm.py
│ │ └── resource
│ │ │ ├── __init__.py
│ │ │ ├── predict_lm.py
│ │ │ ├── run_shallow.py
│ │ │ └── tune_lm.py
│ └── make
│ │ ├── __init__.py
│ │ ├── corpus.py
│ │ └── resource
│ │ ├── __init__.py
│ │ ├── dataset.py
│ │ └── proposals.py
│ ├── settings.py
│ ├── utils.py
│ └── vendor
│ ├── __init__.py
│ └── skopt.py
└── tests
├── __init__.py
├── conftest.py
├── fixtures
├── corpus-easy
│ ├── dev.scruples-corpus.jsonl
│ ├── test.scruples-corpus.jsonl
│ └── train.scruples-corpus.jsonl
├── gold-extractions
│ └── gold-title-description-extractions.jsonl
├── raw-data
│ ├── test-comment.jsonl
│ ├── test-post-comments.jsonl
│ └── test-post.jsonl
└── resource-easy
│ ├── dev.scruples-actions.jsonl
│ ├── test.scruples-actions.jsonl
│ └── train.scruples-actions.jsonl
├── scruples
├── __init__.py
├── analysis
│ ├── __init__.py
│ └── test_traits.py
├── baselines
│ ├── __init__.py
│ ├── test_labels.py
│ ├── test_linear.py
│ ├── test_loss.py
│ ├── test_metrics.py
│ ├── test_naivebayes.py
│ ├── test_style.py
│ ├── test_utils.py
│ └── utils.py
├── data
│ ├── __init__.py
│ ├── test_action.py
│ ├── test_comment.py
│ ├── test_label_scores.py
│ ├── test_labels.py
│ ├── test_post.py
│ ├── test_post_types.py
│ └── test_utils.py
├── dataset
│ ├── __init__.py
│ ├── test_readers.py
│ └── test_transforms.py
├── demos
│ ├── __init__.py
│ ├── norms
│ │ ├── __init__.py
│ │ ├── test_app.py
│ │ └── test_utils.py
│ └── scoracle
│ │ ├── __init__.py
│ │ └── test_app.py
├── extraction
│ ├── __init__.py
│ ├── test_base.py
│ ├── test_filters.py
│ ├── test_normalizers.py
│ └── test_transformers.py
├── test_settings.py
└── test_utils.py
├── settings.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # pipenv
86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
88 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
89 | # install all needed dependencies.
90 | #Pipfile.lock
91 |
92 | # celery beat schedule file
93 | celerybeat-schedule
94 |
95 | # SageMath parsed files
96 | *.sage.py
97 |
98 | # Environments
99 | .env
100 | .venv
101 | env/
102 | venv/
103 | ENV/
104 | env.bak/
105 | venv.bak/
106 |
107 | # Spyder project settings
108 | .spyderproject
109 | .spyproject
110 |
111 | # Rope project settings
112 | .ropeproject
113 |
114 | # mkdocs documentation
115 | /site
116 |
117 | # mypy
118 | .mypy_cache/
119 | .dmypy.json
120 | dmypy.json
121 |
122 | # Pyre type checker
123 | .pyre/
--------------------------------------------------------------------------------
/.skiff/cloudbuild-deploy.yaml:
--------------------------------------------------------------------------------
1 | # This file tells Google Cloud Build how to deploy the application.
2 | # It can be attached to a variety of triggers, the default being whenever
3 | # someone merges changes to the `master` branch.
4 | timeout: '3600s'
5 | steps:
6 | # Pull the latest Docker image, so the build is faster.
7 | - id: 'image.pull'
8 | waitFor: ['-']
9 | name: 'gcr.io/cloud-builders/docker'
10 | entrypoint: '/bin/bash'
11 | args: [
12 | '-c',
13 | 'docker pull gcr.io/$PROJECT_ID/$REPO_NAME:latest || exit 0'
14 | ]
15 | # Build the Docker image
16 | - id: 'image.build'
17 | waitFor: ['image.pull']
18 | name: 'gcr.io/cloud-builders/docker'
19 | args: [
20 | 'build',
21 | '--tag', 'gcr.io/$PROJECT_ID/$REPO_NAME:latest',
22 | '--tag', 'gcr.io/$PROJECT_ID/$REPO_NAME:$COMMIT_SHA',
23 | '--cache-from', 'gcr.io/$PROJECT_ID/$REPO_NAME:latest',
24 | '.',
25 | ]
26 | # Push the Docker image to be deployed
27 | - id: 'image.push'
28 | waitFor: ['image.build']
29 | name: 'gcr.io/cloud-builders/docker'
30 | args: [
31 | 'push',
32 | 'gcr.io/$PROJECT_ID/$REPO_NAME:$COMMIT_SHA',
33 | ]
34 | # Generate the Kubernetes configuration for the norms demo
35 | - id: 'config.norms'
36 | waitFor: ['-']
37 | name: 'gcr.io/ai2-reviz/jsonnet'
38 | args: [
39 | 'eval',
40 | '--yaml-stream',
41 | '--output-file', './norms.webapp.yaml',
42 | '--tla-str', 'app=norms',
43 | '--tla-str', 'env=$_ENV',
44 | '--tla-str', 'image=gcr.io/$PROJECT_ID/$REPO_NAME:$COMMIT_SHA',
45 | '--tla-str', 'sha=$COMMIT_SHA',
46 | '--tla-str', 'cause=Automated Skiff Deploy SHA:$COMMIT_SHA BUILD:$BUILD_ID',
47 | '--tla-str', 'branch=$BRANCH_NAME',
48 | '--tla-str', 'repo=$REPO_NAME',
49 | '--tla-str', 'buildId=$BUILD_ID',
50 | './webapp.jsonnet'
51 | ]
52 | dir: '.skiff'
53 | # Generate the Kubernetes configuration for the scoracle demo
54 | - id: 'config.scoracle'
55 | waitFor: ['-']
56 | name: 'gcr.io/ai2-reviz/jsonnet'
57 | args: [
58 | 'eval',
59 | '--yaml-stream',
60 | '--output-file', './scoracle.webapp.yaml',
61 | '--tla-str', 'app=scoracle',
62 | '--tla-str', 'env=$_ENV',
63 | '--tla-str', 'image=gcr.io/$PROJECT_ID/$REPO_NAME:$COMMIT_SHA',
64 | '--tla-str', 'sha=$COMMIT_SHA',
65 | '--tla-str', 'cause=Automated Skiff Deploy SHA:$COMMIT_SHA BUILD:$BUILD_ID',
66 | '--tla-str', 'branch=$BRANCH_NAME',
67 | '--tla-str', 'repo=$REPO_NAME',
68 | '--tla-str', 'buildId=$BUILD_ID',
69 | './webapp.jsonnet'
70 | ]
71 | dir: '.skiff'
72 | # Deploy the norms demo to Kubernetes
73 | - id: 'deploy.norms'
74 | waitFor: ['image.push', 'config.norms']
75 | name: 'gcr.io/ai2-reviz/rudder'
76 | args: [
77 | 'deploy',
78 | '-f',
79 | 'norms.webapp.yaml'
80 | ]
81 | dir: '.skiff'
82 | # Deploy the scoracle demo to Kubernetes
83 | - id: 'deploy.scoracle'
84 | waitFor: ['image.push', 'config.scoracle']
85 | name: 'gcr.io/ai2-reviz/rudder'
86 | args: [
87 | 'deploy',
88 | '-f',
89 | 'scoracle.webapp.yaml'
90 | ]
91 | dir: '.skiff'
92 | substitutions:
93 | _ENV: staging
94 | images: [
95 | 'gcr.io/$PROJECT_ID/$REPO_NAME:$COMMIT_SHA',
96 | 'gcr.io/$PROJECT_ID/$REPO_NAME:latest'
97 | ]
98 | artifacts:
99 | objects:
100 | location: 'gs://skiff-archive/$REPO_NAME/$_ENV/$BUILD_ID/$COMMIT_SHA'
101 | paths: ['.skiff/norms.webapp.yaml', '.skiff/scoracle.webapp.yaml']
102 |
--------------------------------------------------------------------------------
/.skiff/norms.skiff.json:
--------------------------------------------------------------------------------
1 | {
2 | "appName": "norms",
3 | "contact": "ronanlb",
4 | "team": "aristo",
5 | "args": [ "demo", "norms", "--port", "8000", "--expose" ],
6 | "resources": {
7 | "requests": {
8 | "cpu": "1",
9 | "memory": "16Gi"
10 | }
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/.skiff/scoracle.skiff.json:
--------------------------------------------------------------------------------
1 | {
2 | "appName": "scoracle",
3 | "contact": "ronanlb",
4 | "team": "aristo",
5 | "args": [ "demo", "scoracle", "--port", "8000", "--expose" ],
6 | "resources": {
7 | "requests": {
8 | "cpu": "1.0",
9 | "memory": "2Gi"
10 | }
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Usage: A containerized environment for scruples.
2 |
3 | FROM nvidia/cuda:10.1-devel-ubuntu18.04
4 |
5 |
6 | # Configure language settings.
7 |
8 | ENV LC_ALL=C.UTF-8 \
9 | LANG=C.UTF-8
10 |
11 |
12 | # Create the working directory.
13 |
14 | ENV APP_DIR=/app
15 |
16 | WORKDIR $APP_DIR
17 |
18 |
19 | # Install the OS packages.
20 |
21 | RUN apt-get update --fix-missing \
22 | && DEBIAN_FRONTEND=noninteractive \
23 | apt-get install --no-install-recommends --yes \
24 | build-essential \
25 | curl \
26 | git \
27 | libbz2-dev \
28 | libffi-dev \
29 | liblzma-dev \
30 | libmysqlclient-dev \
31 | libncurses5-dev \
32 | libreadline-dev \
33 | libsqlite3-dev \
34 | libssl-dev \
35 | libxml2-dev \
36 | libxmlsec1-dev \
37 | llvm \
38 | make \
39 | tk-dev \
40 | wget \
41 | xz-utils \
42 | zlib1g-dev
43 |
44 |
45 | # Install Python.
46 |
47 | # Set environment variables for pyenv.
48 |
49 | ENV PYTHON_VERSION="3.7.0" \
50 | PYENV_ROOT="$APP_DIR/.pyenv"
51 |
52 | ENV PATH="$PYENV_ROOT/bin:$PATH"
53 |
54 | # Install pyenv version 1.2.15.
55 |
56 | RUN git clone https://github.com/pyenv/pyenv.git $PYENV_ROOT \
57 | && cd $PYENV_ROOT \
58 | && git checkout b5f69fe58a58f3d2d59511c4630c96aafe5ad2b0
59 |
60 | # Install the python version.
61 |
62 | RUN pyenv install $PYTHON_VERSION
63 |
64 | # Place the new python version on the path.
65 |
66 | ENV PATH="$PYENV_ROOT/versions/$PYTHON_VERSION/bin:$PATH"
67 |
68 |
69 | # Install python package dependencies.
70 |
71 | COPY requirements.txt .
72 |
73 | RUN pip install --upgrade pip \
74 | && pip install --requirement requirements.txt
75 |
76 |
77 | # Download the spacy models.
78 |
79 | RUN python -m spacy download en
80 |
81 |
82 | # Setup the norms demo.
83 |
84 | # Download the models.
85 |
86 | RUN mkdir --parents models/norms
87 |
88 | RUN mkdir models/norms/actions \
89 | && wget \
90 | --quiet \
91 | --directory-prefix models/norms/actions \
92 | https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/actions/config.json \
93 | https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/actions/pytorch_model.bin
94 |
95 | RUN mkdir models/norms/corpus \
96 | && wget \
97 | --quiet \
98 | --directory-prefix models/norms/corpus \
99 | https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/corpus/config.json \
100 | https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/corpus/pytorch_model.bin
101 |
102 | # Set the environment variables.
103 |
104 | ENV SCRUPLES_NORMS_ACTIONS_BASELINE="roberta" \
105 | SCRUPLES_NORMS_ACTIONS_MODEL="/app/models/norms/actions" \
106 | SCRUPLES_NORMS_CORPUS_BASELINE="roberta" \
107 | SCRUPLES_NORMS_CORPUS_MODEL="/app/models/norms/corpus" \
108 | SCRUPLES_NORMS_PREDICT_BATCH_SIZE="2" \
109 | SCRUPLES_NORMS_GPU_IDS=""
110 |
111 |
112 | # Copy over the source code
113 |
114 | COPY . .
115 |
116 |
117 | # Install the scruples package.
118 |
119 | RUN pip install .
120 |
121 |
122 | # Define the entrypoint and default command.
123 |
124 | ENTRYPOINT [ "scruples" ]
125 |
126 | CMD [ "demo", "norms", "--port", "8000", "--expose" ]
127 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include readme.md
2 | include docs/*.md
3 | include src/scruples/demos/norms/templates/*
4 | include src/scruples/demos/norms/static/*
5 | include src/scruples/demos/scoracle/templates/*
6 | include src/scruples/demos/scoracle/static/*
7 | include tests/fixtures/*
--------------------------------------------------------------------------------
/bin/scruples:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | """The command line interface for scruples."""
4 |
5 | from scruples import scripts
6 |
7 |
8 | if __name__ == '__main__':  # script entry point: dispatch to the click CLI group defined in scruples.scripts
9 |     scripts.scruples()
10 |
--------------------------------------------------------------------------------
/docs/annotation-guidelines.md:
--------------------------------------------------------------------------------
1 | Annotation Guidelines
2 | =====================
3 | Guidelines for annotating the ground truth for the data.
4 |
5 | These guidelines define how to annotate the ground truth for the reddit
6 | data in order to produce extraction quality evaluations. There are two
7 | kinds of annotations used in evaluating the extractions: comment and
8 | post annotations.
9 |
10 |
11 | Comment Annotations
12 | -------------------
13 | Comment annotations should be written in a JSON Lines file where each
14 | object has the following keys:
15 |
16 |
17 | id
18 | -
19 | the ID attribute of the corresponding comment
20 |
21 | label
22 | -
23 | a gold annotation for the label (one of `AUTHOR`,
24 | `OTHER`, `EVERYBODY`,
25 | `NOBODY`, or `INFO`) expressed by the
26 | comment, or `null` if no label is expressed
27 |
28 | implied
29 | -
30 | `true` if the label is implied by the view of the author
31 | and `false` if the label is somehow explicitly stated
32 |
33 | spam
34 | -
35 | `true` if the comment is spam, `false`
36 | otherwise
37 |
38 |
39 |
40 | The possible labels are:
41 |
42 |
43 | AUTHOR
44 | -
45 | The author of the anecdote is in the wrong.
46 |
47 | OTHER
48 | -
49 | The other person in the anecdote is in the wrong.
50 |
51 | EVERYBODY
52 | -
53 | Everyone in the anecdote is in the wrong.
54 |
55 | NOBODY
56 | -
57 | No one in the anecdote is in the wrong.
58 |
59 | INFO
60 | -
61 | More information is required to make a judgment.
62 |
63 |
64 |
65 | If the comment explicitly expresses a label either by its initialism or
66 | some phrase corresponding to the initialism, then use that label for the
67 | comment. Similarly, mark the comment with `implied` as `false` and
68 | `spam` as `false`.
69 |
70 | If the comment expresses multiple labels with no clear winner or is
71 | otherwise ambiguous, mark `label` as `null`, `implied` as `null`, and
72 | `spam` as `true`.
73 |
74 | If the comment expresses no labels explicitly but still has a viewpoint
75 | that clearly expresses one of the labels, then use that label for the
76 | comment. Mark `implied` as `true` and `spam` as `false`.
77 |
78 | Finally, if the comment expresses no label (i.e., none of `AUTHOR`,
79 | `OTHER`, `NOBODY`, `EVERYBODY`, or `INFO`), then mark `label` as `null`,
80 | `implied` as `null`, and `spam` as `true`.
81 |
82 |
83 | Post Annotations
84 | ----------------
85 | Post annotations should be written in a JSON Lines file where each
86 | object has the following keys:
87 |
88 |
89 | id
90 | -
91 | the ID attribute of the corresponding post
92 |
93 | post_type
94 | -
95 | a gold annotation for the post's type
96 |
97 | implied
98 | -
99 | `true` if the post type is not explicitly stated in the
100 | post title.
101 |
102 | spam
103 | -
104 | `true` if the post is spam, `false` otherwise
105 |
106 |
107 |
108 | Possible post types are:
109 |
110 |
111 | HISTORICAL
112 | - The author is asking "am I the a**hole?"
113 | HYPOTHETICAL
114 | - The author is asking "would I be the a**hole?"
115 | META
116 | - The post is about the subreddit itself.
117 |
118 |
119 | If the post type is explicitly stated in the post title, then mark
120 | `post_type` as the stated post type, mark `implied` as `false`, and
121 | `spam` as `false`, unless the post type is `META` in which case mark
122 | `spam` as `true`. Additionally, if the post type is explicitly stated but
123 | clearly wrong (such as using HISTORICAL for a HYPOTHETICAL post), then
124 | use the true post type rather than the stated one.
125 |
126 | If the post type is not explicitly stated in the post title, but
127 | otherwise clear from the post, mark the appropriate post type, mark
128 | `implied` as `true` and `spam` as `false`.
129 |
130 | If the post cannot be categorized into one of the types above, mark the
131 | `post_type` as `null`, `implied` as `null`, and `spam` as `true`.
132 |
133 | If the post is something that should not be present in the dataset (for
134 | example a deleted post), then mark `spam` as `true`.
135 |
--------------------------------------------------------------------------------
/docs/assets/demo-screenshots/norms-about.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/scruples/c79697e97d24f5b43b5ff030a9b16011a32e4a5f/docs/assets/demo-screenshots/norms-about.png
--------------------------------------------------------------------------------
/docs/assets/demo-screenshots/norms-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/scruples/c79697e97d24f5b43b5ff030a9b16011a32e4a5f/docs/assets/demo-screenshots/norms-results.png
--------------------------------------------------------------------------------
/docs/assets/demo-screenshots/scoracle-about.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/scruples/c79697e97d24f5b43b5ff030a9b16011a32e4a5f/docs/assets/demo-screenshots/scoracle-about.png
--------------------------------------------------------------------------------
/docs/assets/demo-screenshots/scoracle-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/scruples/c79697e97d24f5b43b5ff030a9b16011a32e4a5f/docs/assets/demo-screenshots/scoracle-results.png
--------------------------------------------------------------------------------
/docs/demos.md:
--------------------------------------------------------------------------------
1 | Demos
2 | =====
3 | Scruples ships with two demos, [scoracle](#scoracle) and [norms](#norms). You
4 | can visit these demos live on the web, or you can run them yourself.
5 |
6 |
7 | Setup
8 | -----
9 | To run either of the demos, first you need to follow the
10 | [Setup](../readme.md#setup) and [Quickstart](../readme.md#quickstart)
11 | documentation in the [readme](../readme.md).
12 |
13 |
14 | Scoracle
15 | --------
16 | To run scoracle, use the `scruples demo scoracle` command. The command is
17 | self-documenting:
18 |
19 | $ scruples demo scoracle --help
20 | Usage: scruples demo scoracle [OPTIONS]
21 |
22 | Serve the scoracle demo.
23 |
24 | Options:
25 | --expose Whether to expose the server to the internet, or run on
26 | localhost.
27 | --port INTEGER The port on which to serve the demo. Defaults to 5000.
28 | --help Show this message and exit.
29 |
30 | So, to run scoracle on localhost at port 5000, execute:
31 |
32 | scruples demo scoracle
33 |
34 | When you visit the site, you should see something like this:
35 |
36 | 
37 |
38 | After submitting a dataset and requesting the BEST performance for some
39 | metrics, you'll see something like this:
40 |
41 | 
42 |
43 |
44 | Norms
45 | -----
46 | Before running the `norms` demo, you'll have to download the config files and
47 | weights:
48 |
49 | - **Scruples Anecdotes demo model**: [config][anecdotes-demo-config]
50 | [weights][anecdotes-demo-weights]
51 | - **Scruples Dilemmas demo model**: [config][dilemmas-demo-config]
52 | [weights][dilemmas-demo-weights]
53 |
54 | Once you've obtained the weights, set the following environment variables:
55 |
56 | SCRUPLES_NORMS_ACTIONS_BASELINE=roberta
57 | SCRUPLES_NORMS_ACTIONS_MODEL=$DILEMMAS_MODEL_PATH
58 | SCRUPLES_NORMS_CORPUS_BASELINE=roberta
59 | SCRUPLES_NORMS_CORPUS_MODEL=$ANECDOTES_MODEL_PATH
60 | SCRUPLES_NORMS_PREDICT_BATCH_SIZE=$BATCH_SIZE
61 | SCRUPLES_NORMS_GPU_IDS=$GPU_IDS
62 |
63 | For each of the models, the path to it should be a directory containing both
64 | the config and the weights file. `$GPU_IDS` should be a comma separated string
65 | of integers corresponding to which GPUs to use. You can set
66 | `SCRUPLES_NORMS_GPU_IDS` to the empty string to use the CPU instead. Start out
67 | with a value of 1 or 2 for `$BATCH_SIZE`, increasing it if you need more
68 | performance and have hardware that can support larger batches.
69 |
70 | To run norms, use the `scruples demo norms` command. The command is
71 | self-documenting:
72 |
73 | $ scruples demo norms --help
74 | Usage: scruples demo norms [OPTIONS]
75 |
76 | Serve the norms demo.
77 |
78 | In order to run this server, you must set the following environment
79 | variables:
80 |
81 | SCRUPLES_NORMS_ACTIONS_BASELINE : The baseline to use for the resource
82 | (actions).
83 | SCRUPLES_NORMS_ACTIONS_MODEL : The path to the saved pretrained
84 | model to use for predicting the actions.
85 | SCRUPLES_NORMS_CORPUS_BASELINE : The baseline to use for the corpus.
86 | SCRUPLES_NORMS_CORPUS_MODEL : The path to the saved pretrained
87 | model to use for predicting the corpus.
88 | SCRUPLES_NORMS_PREDICT_BATCH_SIZE : The batch size to use for
89 | prediction.
90 | SCRUPLES_NORMS_GPU_IDS : A comma separated list of GPU IDs to
91 | use. If none are provided, then the CPU will be used instead.
92 |
93 | Options:
94 | --expose Whether to expose the server to the internet, or run on
95 | localhost.
96 | --port INTEGER The port on which to serve the demo. Defaults to 5050.
97 | --help Show this message and exit.
98 |
99 | So, to run norms on localhost at port 5050, execute:
100 |
101 | scruples demo norms
102 |
103 | When you visit the site, you should see something like this:
104 |
105 | 
106 |
107 | And after submitting a dilemma to the model, you should see results like this:
108 |
109 | 
110 |
111 |
112 | [anecdotes-demo-config]: https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/corpus/config.json
113 | [anecdotes-demo-weights]: https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/corpus/pytorch_model.bin
114 | [dilemmas-demo-config]: https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/actions/config.json
115 | [dilemmas-demo-weights]: https://storage.googleapis.com/ai2-mosaic-public/projects/scruples/v1.0/demos/norms/actions/pytorch_model.bin
116 |
--------------------------------------------------------------------------------
/hit-definitions/resource-labeling/data.jsonl:
--------------------------------------------------------------------------------
1 | {"instances":[{"id":"gVJWzXwANx0nDiJfBNBX8XCXUrmY5iiS","actions":[{"id":"W0zf9dPGrSCeFY2LRbAOknRppO24JkKM","description":"not encouraging my 6yo to have a relationship with her bio dad and 4 half siblings"},{"id":"f3qwv4n2oHAe09qBumna9K2HSu5u2jjS","description":"poking fun at my wife for getting off her diet"}],"label":0},{"id":"TxZG9YEwh575rli6L5P08zRaqMq8WrpH","actions":[{"id":"56s93ij0Biuaro7p0cgkbBeP22FUugjT","description":"using the bathroom together with my boyfriend despite our roommate asking us not to because it makes her uncomfortable"},{"id":"RMTVlCSVNNrPKl19u7D1yICikSPExskv","description":"saying \"who cares\" in a topic about a dead guy with CANCER"}],"label":0},{"id":"Y6SUccZCkoH2R4Pese1nza7HsXlrb32G","actions":[{"id":"XIvbvXQTL9cgp6Sz3TvUlg0Yzvt3r9mE","description":"leaving spoiled food for a food thief to steal"},{"id":"lrbFHh9LucZrDHhqVpoCTMQnUDHwO8bN","description":"filing a charge back against an artist who got my character wrong"}],"label":0},{"id":"wxDLCFx1do97axndP0hCeEkNtq5q2N7Z","actions":[{"id":"dCt8eqx9DORDvKd7drSFQbALuPTlLcIC","description":"saying you can't 'accidentally' sleep with someone"},{"id":"IJol2KEyplmWNVp9ChYr3wMeKZg7vELC","description":"slapping the shit out of a good friend"}],"label":0},{"id":"h2lsDLQBh1QePOgEtdIxnFqcueqasY67","actions":[{"id":"aPMp62Pl181ymGnxuHAssmUDfevDTgiD","description":"being pissed off at my metalhead co-worker"},{"id":"fkyecpgn2jjtlavGVoZAUIMq33rgXGdo","description":"pursuing my friend even after she's said no"}],"label":0},{"id":"SaxKK3kqIz1GhvcbvTUZ9i6sJdAsQcpL","actions":[{"id":"6qpRWTv5cS7gvswu9Sym8fnTW83SDVwE","description":"switching obese people's drink at McDonald's"},{"id":"ltf6CjS90dybzOfpmS1fjMi0C9bovbOG","description":"telling girl I've been seeing that each gender has its own advantages"}],"label":1},{"id":"kyMVRMbHCPLIwXMSenMlQxkDQ8VbejPp","actions":[{"id":"ci7JOi6a1yUZXLHGWZVQEBjEGG3ONSml","description":"rejecting a 'Suicidal' 
girl"},{"id":"K7pIEoSUkcKA3TGsmTZmHgaIiMN8ZWia","description":"considering my alcoholic husband's (sober) lies a separate \"choice\" from the disease itself"}],"label":0},{"id":"QXPYHAc3xZiQkO5L5VGqPyIK7lxYm5Ak","actions":[{"id":"TR35CUVusNNRR9PatcbIADCYiwwmPvfy","description":"criticizing my girlfriend's excessive Instagram browsing when I myself browse Reddit the same amount"},{"id":"tbbmaVTDLKIjvFYsBn9kotQSVouL8yto","description":"\"wanting to speak to the manager\""}],"label":1},{"id":"aw3pyzwwCI4x30EcOQQ5ARzjEhhOFxSG","actions":[{"id":"zWb9ZYmHdsN37JPKrgAUQrAQZZXqruwP","description":"leaving my boyfriend after he got belligerently drunk and pissed in the floor"},{"id":"6pQlP6mvYvBZP12g3G1EQpXQFn9msuT3","description":"yelling at some kids for not including my disabled daughter in play"}],"label":0},{"id":"AlgukffQBwubC84m3qnXC52XutCwTz8M","actions":[{"id":"PV2bVUdxZwDkePyRyAjMx1vy4XYBIy8q","description":"losing a friend who moved overseas"},{"id":"YyRNoQO6W4Oqo6NJyZ5PmAmDSYVIqKVs","description":"not wanting to continue pregnancy"}],"label":1},{"id":"Cc6kKRjf0D42XpmAvwiKqyiOEfAcBEs8","actions":[{"id":"TqkPxBhuSpR6IHq1S5Ah1VzpY4F52JLD","description":"possibly contributing to a classmate developing severe social anxiety and dropping out of school"},{"id":"IZ8duK1YfDHiClkK6TXrJvOsmLewuGbO","description":"reminding my gf about the meals I paid for"}],"label":1},{"id":"T7ejY85ym1RqsSTXmuyEFOSUsd4g4wIP","actions":[{"id":"94a5I6MFvMfHWDMM0H2jL8UaR0qV9iwM","description":"purposely one upping my broke brother with gifts for his kids"},{"id":"G0khTLzOFpLLfKW3wzjkdJ6fkWFykqMo","description":"snapping at a girl who has a short-ish temper"}],"label":1},{"id":"bYpGXBrdKy8UWLMECeD8JeU2FAwuN6Ui","actions":[{"id":"9LasoY9QCCBLA4zFRlqtQyOaDTFx6faJ","description":"calling the RA on a girl in the bathroom"},{"id":"IPvfLdTjrGbRjweSWCGeOCLAGkGuO0Nu","description":"recognizing Old Enemy as 
Juror"}],"label":0},{"id":"XRuFXoM3p6mlNa8Fn6ZOdXLRAxf8Tf5c","actions":[{"id":"hWZwVlcawqGxkIpHwJ40GHpVlKRg0S1g","description":"getting mad at my friend for stealing my phone and going into the bathroom with it"},{"id":"i9G2t1gtJwY5ddAzn4T4Pctv6spJb2Su","description":"batting away a lady's hand when she tried to touch my hair"}],"label":1},{"id":"zWuZ0WpFN44oReu4hQ7FN8ILSUzV1ChY","actions":[{"id":"7zeCf23MmqSrphChNgQb6Od9Hwjgj8Q4","description":"not wanting to work a day I asked for off a month ago"},{"id":"eruzElCLr6s9RwN1FrBWHqkYAq1FjEVS","description":"keying a car that was parked in my assigned spot"}],"label":0},{"id":"O2Doeu6A1oVgPjxqTdbPwH6wmdUcxnfm","actions":[{"id":"GYxsQrHIjDPn7a355UZWD0tMoXpRp2hJ","description":"becoming best friends with a friend's crush and now he's having a hard time because he feels I'm a better match for her than he is"},{"id":"SNYOv3kG67s5RR5IXnIqZ2IHq7t32eXv","description":"calling the police/cps on my SIL"}],"label":1},{"id":"Pot8x0BLD9KpbmVsUQOgSe3rKvlgfItX","actions":[{"id":"7KdzoC73NCI3t4TU0OfbfxMB5OspcFbn","description":"asking my neighbor to make sure their kids do not shoot fireworks so close to the car"},{"id":"QSxaCCE2JAUylq18OXRQb4UShKLBgsRR","description":"not going to my best friend and housemates hen party"}],"label":1},{"id":"Nx8f7BpokrNqOwZC4WbpVk2eOUOYPfvc","actions":[{"id":"i0bHuPXMTd3fmlpaTEcWWKmnM77dOpjv","description":"telling my ex-girlfriend's sister that she would never see my kids because she is pornstar"},{"id":"m0WHBUD8zXHYLVJiQsUCUW9xOrSF78wm","description":"being upset with my wife about a death in our family"}],"label":1},{"id":"gq9MOCBomD9o1KdQQhZqOISwAXKvq2SE","actions":[{"id":"QgW7pF3Soh3qo7JpWVclEuuehX9kuXnN","description":"blocking our tenant from the driveway"},{"id":"3fQx0WkNAgNshQwHvqch98AGKEcbpAio","description":"breaking up with my so who revealed they're trans"}],"label":1},{"id":"8gjhpGjfo1hmDiEdInpM06Ui5Bu73X9A","actions":[{"id":"dKwpB2xvbgr35SWdg7ibkQQQUhPmFqJ4","description":"telling my 
wife I won't try to have kids until she starts trying to take care of herself"},{"id":"9uJTZl1VhVFZKoNlfZ3ZPlNesbb4mlTn","description":"taking food off my friend's plate without asking first"}],"label":0}]}
2 |
--------------------------------------------------------------------------------
/hit-definitions/resource-labeling/definition/NOTES:
--------------------------------------------------------------------------------
1 | This HIT asks people to label instances from the scruples resource.
2 |
--------------------------------------------------------------------------------
/hit-definitions/resource-labeling/definition/hitproperties.json:
--------------------------------------------------------------------------------
1 | {
2 | "MaxAssignments": 10,
3 | "LifetimeInSeconds": 259200
4 | }
5 |
--------------------------------------------------------------------------------
/hit-definitions/resource-labeling/definition/hittypeproperties.json:
--------------------------------------------------------------------------------
1 | {
2 | "AutoApprovalDelayInSeconds": 7200,
3 | "AssignmentDurationInSeconds": 7200,
4 | "Reward": "0.60",
5 | "Title": "What Would Be Worse? (WARNING: adult content / situations)",
6 | "Keywords": "artificial intelligence common sense commonsense ethics",
7 | "Description": "Pick the action that would be worse for someone to do.",
8 | "QualificationRequirements": [
9 | {
10 | "QualificationTypeId": "00000000000000000040",
11 | "Comparator": "GreaterThanOrEqualTo",
12 | "IntegerValues": [
13 | 5000
14 | ],
15 | "RequiredToPreview": false
16 | },
17 | {
18 | "QualificationTypeId": "00000000000000000071",
19 | "Comparator": "In",
20 | "LocaleValues": [{"Country": "US"}, {"Country": "CA"}],
21 | "RequiredToPreview": false
22 | },
23 | {
24 | "QualificationTypeId": "000000000000000000L0",
25 | "Comparator": "GreaterThanOrEqualTo",
26 | "IntegerValues": [
27 | 99
28 | ],
29 | "RequiredToPreview": false
30 | },
31 | {
32 | "QualificationTypeId": "00000000000000000060",
33 | "Comparator": "EqualTo",
34 | "IntegerValues": [1],
35 | "RequiredToPreview": true
36 | }
37 | ]
38 | }
39 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.2.0
2 | astunparse==1.6.3
3 | atomicwrites==1.4.1
4 | attrs==19.1.0
5 | autocommand==2.2.1
6 | autograd==1.2
7 | backports.csv==1.0.7
8 | beautifulsoup4==4.11.1
9 | blis==0.2.4
10 | boto3==1.24.82
11 | botocore==1.27.82
12 | cachetools==5.2.0
13 | certifi==2022.9.24
14 | cffi==1.15.1
15 | charset-normalizer==2.1.1
16 | cheroot==8.6.0
17 | CherryPy==18.8.0
18 | Click==7.0
19 | cryptography==38.0.1
20 | cycler==0.11.0
21 | cymem==2.0.6
22 | decorator==5.1.1
23 | dill==0.2.9
24 | feedparser==6.0.10
25 | Flask==1.1.1
26 | flatbuffers==22.9.24
27 | ftfy==5.5.1
28 | future==0.18.2
29 | gast==0.4.0
30 | gevent==21.12.0
31 | google-auth==2.12.0
32 | google-auth-oauthlib==0.4.6
33 | google-pasta==0.2.0
34 | greenlet==1.1.3
35 | grpcio==1.49.1
36 | h5py==3.7.0
37 | idna==3.4
38 | importlib-metadata==4.12.0
39 | importlib-resources==5.9.0
40 | inflect==6.0.0
41 | itsdangerous==1.1.0
42 | jaraco.classes==3.2.3
43 | jaraco.collections==3.5.2
44 | jaraco.context==4.1.2
45 | jaraco.functools==3.5.2
46 | jaraco.text==3.9.1
47 | Jinja2==2.11.3
48 | jmespath==1.0.1
49 | joblib==1.2.0
50 | jsonschema==2.6.0
51 | keras==2.10.0
52 | Keras-Preprocessing==1.1.2
53 | kiwisolver==1.4.4
54 | libclang==14.0.6
55 | lxml==4.9.1
56 | Markdown==3.4.1
57 | MarkupSafe==1.1.1
58 | matplotlib==3.0.3
59 | more-itertools==8.14.0
60 | murmurhash==1.0.8
61 | mysqlclient==2.1.1
62 | networkx==2.3
63 | nltk==3.6.3
64 | numpy==1.21.6
65 | oauthlib==3.2.1
66 | opt-einsum==3.3.0
67 | packaging==21.3
68 | pandas==0.24.2
69 | patsy==0.5.2
70 | Pattern==3.6
71 | pdfminer.six==20220524
72 | Pillow==9.2.0
73 | plac==0.9.6
74 | pluggy==1.0.0
75 | portend==3.1.0
76 | preshed==2.0.1
77 | protobuf==3.19.5
78 | py==1.11.0
79 | pyasn1==0.4.8
80 | pyasn1-modules==0.2.8
81 | pycparser==2.21
82 | pydantic==1.10.2
83 | pyparsing==3.0.9
84 | pytest==4.4.1
85 | python-dateutil==2.8.0
86 | python-docx==0.8.11
87 | pytz==2022.2.1
88 | regex==2018.1.10
89 | requests==2.28.1
90 | requests-oauthlib==1.3.1
91 | rsa==4.9
92 | s3transfer==0.6.0
93 | sacremoses==0.0.53
94 | scikit-learn==0.20.3
95 | scikit-optimize==0.5.2
96 | scipy==1.2.1
97 | seaborn==0.9.0
98 | sentencepiece==0.1.97
99 | sgmllib3k==1.0.0
100 | six==1.16.0
101 | soupsieve==2.3.2.post1
102 | spacy==2.1.3
103 | srsly==1.0.5
104 | statsmodels==0.11.1
105 | tempora==5.0.2
106 | tensorboard==2.10.1
107 | tensorboard-data-server==0.6.1
108 | tensorboard-plugin-wit==1.8.1
109 | tensorboardX==1.6
110 | tensorflow==2.10.0
111 | tensorflow-estimator==2.10.0
112 | tensorflow-io-gcs-filesystem==0.27.0
113 | termcolor==2.0.1
114 | thinc==7.0.8
115 | torch==1.3.1
116 | torchvision==0.4.2
117 | tqdm==4.32.2
118 | transformers==2.1.1
119 | typing_extensions==4.3.0
120 | urllib3==1.26.12
121 | wasabi==0.10.1
122 | wcwidth==0.2.5
123 | Werkzeug==1.0.1
124 | wrapt==1.14.1
125 | xgboost==0.90
126 | zc.lockfile==2.0
127 | zipp==3.8.1
128 | zope.event==4.5.0
129 | zope.interface==5.4.0
130 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 |
4 |
5 | [tool:pytest]
6 | addopts =
7 | --strict
8 | -p no:warnings
9 | markers =
10 | slow
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """setup.py file for packaging ``scruples``."""
2 |
3 | from setuptools import setup
4 |
5 |
6 | with open('readme.md', 'r') as readme_file:
7 | readme = readme_file.read()
8 |
9 | with open('requirements.txt', 'r') as requirements_file:
10 | requirements = requirements_file.readlines()
11 |
12 |
13 | setup(
14 | name='scruples',
15 | version='0.3.0',
16 | description='A corpus and code for understanding norms and subjectivity.',
17 | long_description=readme,
18 | url='https://github.com/allenai/scruples',
19 | author='Nicholas Lourie',
20 | author_email='nicholasl@allenai.org',
21 | keywords='scruples social norms artificial intelligence ai'
22 | ' machine learning ml',
23 | classifiers=[
24 | 'Development Status :: 3 - Alpha',
25 | 'Programming Language :: Python :: 3.7',
26 | 'License :: OSI Approved :: Apache Software License',
27 | 'Intended Audience :: Science/Research',
28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence'
29 | ],
30 | license='Apache',
31 | packages=[
32 | 'scruples',
33 | 'scruples.analysis',
34 | 'scruples.baselines',
35 | 'scruples.data',
36 | 'scruples.dataset',
37 | 'scruples.demos',
38 | 'scruples.demos.norms',
39 | 'scruples.demos.scoracle',
40 | 'scruples.extraction',
41 | 'scruples.scripts',
42 | 'scruples.scripts.analyze',
43 | 'scruples.scripts.analyze.corpus',
44 | 'scruples.scripts.analyze.resource',
45 | 'scruples.scripts.demo',
46 | 'scruples.scripts.evaluate',
47 | 'scruples.scripts.evaluate.corpus',
48 | 'scruples.scripts.evaluate.resource',
49 | 'scruples.scripts.make',
50 | 'scruples.scripts.make.resource',
51 | 'scruples.vendor'
52 | ],
53 | package_dir={'': 'src'},
54 | scripts=['bin/scruples'],
55 | install_requires=[
56 | 'Click >= 7.0',
57 | 'Flask >= 1.1.1',
58 | 'Pattern >= 3.6',
59 | 'attrs >= 19.1.0',
60 | 'autograd >= 1.2',
61 | 'dill >= 0.2.9',
62 | 'ftfy >= 5.5.1',
63 | 'gevent >= 1.4.0',
64 | 'itsdangerous < 2.0.0',
65 | 'Jinja2 < 3.0.0',
66 | 'MarkupSafe < 2.0.0',
67 | 'matplotlib >= 3.0.3',
68 | 'networkx >= 2.3',
69 | 'numpy >= 1.16.2',
70 | 'pandas >= 0.24.2',
71 | 'regex >= 2018.1.10',
72 | 'scikit-learn >= 0.20.3, < 0.21',
73 | 'scikit-optimize >= 0.5.2',
74 | 'scipy >= 1.2.1, < 1.3',
75 | 'seaborn >= 0.9.0',
76 | 'spacy >= 2.1.3, < 2.2.0',
77 | 'tensorboard >= 1.15.0',
78 | 'tensorboardX >= 1.6',
79 | 'tensorflow >= 1.15.2',
80 | 'tensorflow-estimator >= 1.15.1',
81 | 'torch >= 1.3.1',
82 | 'torchvision >= 0.4.2',
83 | 'transformers >= 2.1.1',
84 | 'Werkzeug < 2.0.0',
85 | 'xgboost >= 0.82'
86 | ],
87 | setup_requires=[
88 | 'pytest-runner'
89 | ],
90 | tests_require=[
91 | 'pytest'
92 | ],
93 | include_package_data=True,
94 | python_requires='>= 3.7',
95 | zip_safe=False)
96 |
--------------------------------------------------------------------------------
/src/scruples/__init__.py:
--------------------------------------------------------------------------------
1 | """Build a dataset for detecting social norm violations."""
2 |
3 | from . import (
4 | analysis,
5 | baselines,
6 | data,
7 | dataset,
8 | extraction,
9 | scripts,
10 | settings,
11 | utils,
12 | vendor)
13 |
--------------------------------------------------------------------------------
/src/scruples/analysis/__init__.py:
--------------------------------------------------------------------------------
1 | """Tools for conducting analyses."""
2 |
3 | from . import traits
4 |
--------------------------------------------------------------------------------
/src/scruples/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | """Baselines for scruples."""
2 |
3 | from . import (
4 | bert,
5 | labels,
6 | linear,
7 | metrics,
8 | naivebayes,
9 | roberta,
10 | style,
11 | train,
12 | trees,
13 | utils)
14 |
--------------------------------------------------------------------------------
/src/scruples/baselines/bert.py:
--------------------------------------------------------------------------------
1 | """BERT baselines."""
2 |
3 | from transformers import (
4 | BertForMultipleChoice,
5 | BertForSequenceClassification,
6 | BertTokenizer)
7 | import skopt
8 | import torch
9 |
10 | from ..data.labels import Label
11 | from ..dataset.transforms import (
12 | BertTransform,
13 | Compose,
14 | Map)
15 |
16 |
17 | # the BERT sequence classification baseline
18 |
19 | BERTClassifier = BertForSequenceClassification.from_pretrained
20 | """Predict fixed classes with a fine-tuned BERT model."""
21 |
22 |
23 | BERT_CLASSIFIER_CONFIG = {
24 | 'model': {
25 | # N.B. pretrained_model_name_or_path for the model must be the
26 | # same as pretrained_bert for the transform
27 | 'pretrained_model_name_or_path': 'bert-large-uncased',
28 | 'num_labels': len(Label)
29 | },
30 | 'transform': {
31 | # N.B. pretrained_bert for the transform must be the same as
32 | # pretrained_model_name_or_path for the model
33 | 'pretrained_bert': 'bert-large-uncased',
34 | 'max_sequence_length': 512,
35 | 'truncation_strategy_title': 'beginning',
36 | 'truncation_strategy_text': 'beginning'
37 | }
38 | }
39 | """Configuration for ``BERTClassifier``."""
40 |
41 |
42 | BERT_CLASSIFIER_HYPER_PARAM_SPACE = [
43 | skopt.space.Real(
44 | low=1e-8,
45 | high=1e-2,
46 | prior='log-uniform',
47 | name='lr'),
48 | skopt.space.Real(
49 | low=1e-5,
50 | high=1e0,
51 | prior='log-uniform',
52 | name='weight_decay'),
53 | skopt.space.Real(
54 | low=0.0,
55 | high=1.0,
56 | prior='uniform',
57 | name='warmup_proportion'),
58 | skopt.space.Integer(
59 | low=1,
60 | high=10,
61 | name='n_epochs'),
62 | skopt.space.Integer(
63 | low=3,
64 | high=10,
65 | name='log_train_batch_size')
66 | ]
67 | """The hyper-param search space for ``BERTClassifier``."""
68 |
69 |
70 | BERT_CLASSIFIER_TRANSFORM = (
71 | lambda
72 | pretrained_bert,
73 | max_sequence_length,
74 | truncation_strategy_title,
75 | truncation_strategy_text:
76 | Compose([
77 | BertTransform(
78 | tokenizer=BertTokenizer.from_pretrained(
79 | pretrained_bert,
80 | do_lower_case=pretrained_bert.endswith('-uncased')),
81 | max_sequence_length=max_sequence_length,
82 | truncation_strategy=(
83 | truncation_strategy_title,
84 | truncation_strategy_text
85 | )),
86 | lambda d: {
87 | 'input_ids': torch.tensor(d['input_ids']),
88 | 'attention_mask': torch.tensor(d['input_mask']),
89 | 'token_type_ids': torch.tensor(d['segment_ids'])
90 | }
91 | ])
92 | )
93 | """The factory to create data transforms for ``BERTClassifier``."""
94 |
95 |
96 | # the BERT ranking baseline
97 |
98 | BERTRanker = BertForMultipleChoice.from_pretrained
99 | """Rank choices with a softmax over a fine-tuned BERT model."""
100 |
101 |
102 | BERT_RANKER_CONFIG = {
103 | 'model': {
104 | # N.B. pretrained_model_name_or_path for the model must be the
105 | # same as pretrained_bert for the transform
106 | 'pretrained_model_name_or_path': 'bert-large-uncased'
107 | },
108 | 'transform': {
109 | # N.B. pretrained_bert for the transform must be the same as
110 | # pretrained_model_name_or_path for the model
111 | 'pretrained_bert': 'bert-large-uncased',
112 | 'max_sequence_length': 92
113 | }
114 | }
115 | """Configuration for ``BERTRanker``."""
116 |
117 |
118 | BERT_RANKER_HYPER_PARAM_SPACE = [
119 | skopt.space.Real(
120 | low=1e-8,
121 | high=1e-2,
122 | prior='log-uniform',
123 | name='lr'),
124 | skopt.space.Real(
125 | low=1e-5,
126 | high=1e0,
127 | prior='log-uniform',
128 | name='weight_decay'),
129 | skopt.space.Real(
130 | low=0.0,
131 | high=1.0,
132 | prior='uniform',
133 | name='warmup_proportion'),
134 | skopt.space.Integer(
135 | low=1,
136 | high=25,
137 | name='n_epochs'),
138 | skopt.space.Integer(
139 | low=3,
140 | high=10,
141 | name='log_train_batch_size')
142 | ]
143 | """The hyper-param search space for ``BERTRanker``."""
144 |
145 |
146 | BERT_RANKER_TRANSFORM = (
147 | lambda
148 | pretrained_bert,
149 | max_sequence_length:
150 | Compose([
151 | # wrap each action in a tuple for passing it to BertTransform
152 | lambda actions: tuple((action, None) for action in actions),
153 | # map BertTransform across all the action choices
154 | Map(
155 | transform=BertTransform(
156 | tokenizer=BertTokenizer.from_pretrained(
157 | pretrained_bert,
158 | do_lower_case=pretrained_bert.endswith('-uncased')),
159 | max_sequence_length=max_sequence_length,
160 | truncation_strategy=('beginning', 'beginning'))),
161 | # collect the action choices and stack their tensors so the
162 | # choices can be their own dimension of the batch
163 | lambda ds: {
164 | 'input_ids': torch.stack([
165 | torch.tensor(d['input_ids'])
166 | for d in ds
167 | ], dim=0),
168 | 'attention_mask': torch.stack([
169 | torch.tensor(d['input_mask'])
170 | for d in ds
171 | ], dim=0),
172 | 'token_type_ids': torch.stack([
173 | torch.tensor(d['segment_ids'])
174 | for d in ds
175 | ], dim=0)
176 | }
177 | ])
178 | )
179 | """The factory to create data transforms for ``BERTRanker``."""
180 |
--------------------------------------------------------------------------------
/src/scruples/baselines/corpus.py:
--------------------------------------------------------------------------------
1 | """Baseline configuration for the scruples corpus."""
2 |
3 | from . import (
4 | bert,
5 | labels,
6 | linear,
7 | naivebayes,
8 | roberta,
9 | style,
10 | trees)
11 |
12 |
13 | SHALLOW_BASELINES = {
14 | 'prior': (
15 | labels.PriorBaseline,
16 | labels.PRIOR_HYPER_PARAMETERS
17 | ),
18 | 'stratified': (
19 | labels.StratifiedBaseline,
20 | labels.STRATIFIED_HYPER_PARAMETERS
21 | ),
22 | 'logisticregression': (
23 | linear.LogisticRegressionBaseline,
24 | linear.LOGISTIC_REGRESSION_HYPER_PARAMETERS
25 | ),
26 | 'bernoullinb': (
27 | naivebayes.BernoulliNBBaseline,
28 | naivebayes.BERNOULLINB_HYPER_PARAMETERS
29 | ),
30 | 'multinomialnb': (
31 | naivebayes.MultinomialNBBaseline,
32 | naivebayes.MULTINOMIALNB_HYPER_PARAMETERS
33 | ),
34 | 'complementnb': (
35 | naivebayes.ComplementNBBaseline,
36 | naivebayes.COMPLEMENTNB_HYPER_PARAMETERS
37 | ),
38 | 'stylistic': (
39 | style.StylisticXGBoostBaseline,
40 | style.STYLISTICXGBOOST_HYPER_PARAMETERS
41 | ),
42 | 'randomforest': (
43 | trees.RandomForestBaseline,
44 | trees.RANDOM_FOREST_HYPER_PARAMETERS
45 | )
46 | }
47 | """Shallow baseline models for the scruples corpus."""
48 |
49 |
50 | FINE_TUNE_LM_BASELINES = {
51 | 'bert': (
52 | bert.BERTClassifier,
53 | bert.BERT_CLASSIFIER_CONFIG,
54 | bert.BERT_CLASSIFIER_HYPER_PARAM_SPACE,
55 | bert.BERT_CLASSIFIER_TRANSFORM
56 | ),
57 | 'roberta': (
58 | roberta.RoBERTaClassifier,
59 | roberta.ROBERTA_CLASSIFIER_CONFIG,
60 | roberta.ROBERTA_CLASSIFIER_HYPER_PARAM_SPACE,
61 | roberta.ROBERTA_CLASSIFIER_TRANSFORM
62 | )
63 | }
64 | """Fine-tuned language model baselines for the scruples corpus."""
65 |
--------------------------------------------------------------------------------
/src/scruples/baselines/labels.py:
--------------------------------------------------------------------------------
1 | """Label only baselines."""
2 |
3 | from sklearn.dummy import DummyClassifier
4 | from sklearn.pipeline import Pipeline
5 |
6 |
7 | # the class prior baseline
8 |
9 | PriorBaseline = Pipeline([
10 | ('classifier', DummyClassifier(strategy='prior'))
11 | ])
12 | """Predict using the label distribution."""
13 |
14 | PRIOR_HYPER_PARAMETERS = {}
15 | """The hyper-param search space for ``PriorBaseline``."""
16 |
17 |
18 | # the stratified sampling baseline
19 |
20 | StratifiedBaseline = Pipeline([
21 | ('classifier', DummyClassifier(strategy='stratified'))
22 | ])
23 | """Predict by sampling a class according to its probability."""
24 |
25 | STRATIFIED_HYPER_PARAMETERS = {}
26 | """The hyper-param search space for ``StratifiedBaseline``."""
27 |
--------------------------------------------------------------------------------
/src/scruples/baselines/linear.py:
--------------------------------------------------------------------------------
1 | """Linear model baselines."""
2 |
3 | from sklearn.feature_extraction.text import (
4 | CountVectorizer,
5 | TfidfTransformer)
6 | from sklearn.linear_model import LogisticRegression
7 | from sklearn.pipeline import Pipeline
8 | from sklearn.preprocessing import FunctionTransformer, MaxAbsScaler
9 |
10 | from . import utils
11 |
12 |
13 | # the logistic regression baseline
14 |
15 | LogisticRegressionBaseline = Pipeline([
16 | (
17 | 'concatenator',
18 | FunctionTransformer(func=utils.concat_title_and_text, validate=False)
19 | ),
20 | (
21 | 'vectorizer',
22 | CountVectorizer(
23 | input='content',
24 | encoding='utf-8',
25 | decode_error='strict',
26 | preprocessor=None,
27 | tokenizer=None,
28 | token_pattern=r'(?u)\b\w\w+\b',
29 | max_features=None,
30 | vocabulary=None)
31 | ),
32 | (
33 | 'tfidf',
34 | TfidfTransformer(smooth_idf=True)
35 | ),
36 | (
37 | 'scaler',
38 | MaxAbsScaler()
39 | ),
40 | (
41 | 'classifier',
42 | LogisticRegression(
43 | penalty='l2',
44 | dual=False,
45 | tol=1e-4,
46 | intercept_scaling=1.,
47 | solver='lbfgs',
48 | max_iter=100,
49 | warm_start=True)
50 | )
51 | ])
52 | """Predict using logistic regression on bag-of-ngrams features."""
53 |
54 | LOGISTIC_REGRESSION_HYPER_PARAMETERS = {
55 | 'vectorizer__strip_accents': ['ascii', 'unicode', None],
56 | 'vectorizer__lowercase': [True, False],
57 | 'vectorizer__stop_words': ['english', None],
58 | 'vectorizer__ngram_range': [
59 | (lo, hi)
60 | for lo in range(1, 2)
61 | for hi in range(lo, lo + 5)
62 | ],
63 | 'vectorizer__analyzer': ['word', 'char', 'char_wb'],
64 | 'vectorizer__max_df': (0.75, 1., 'uniform'),
65 | 'vectorizer__min_df': (0., 0.25, 'uniform'),
66 | 'vectorizer__binary': [True, False],
67 | 'tfidf__norm': ['l1', 'l2', None],
68 | 'tfidf__use_idf': [True, False],
69 | 'tfidf__sublinear_tf': [True, False],
70 | 'classifier__C': (1e-6, 1e2, 'log-uniform'),
71 | 'classifier__fit_intercept': [True, False],
72 | 'classifier__class_weight': ['balanced', None],
73 | 'classifier__multi_class': ['ovr', 'multinomial']
74 | }
75 | """The hyper-param search space for ``LogisticRegressionBaseline``."""
76 |
77 |
78 | # the logistic ranker baseline
79 |
LogisticRankerBaseline = Pipeline([
    (
        'featurizer',
        utils.ResourceTransformer(
            transformer=Pipeline([
                (
                    'vectorizer',
                    CountVectorizer(
                        input='content',
                        encoding='utf-8',
                        # N.B. fixed from 'featurizer' (an invalid value,
                        # apparently copy-pasted from the enclosing step
                        # name): scikit-learn only accepts 'strict',
                        # 'ignore', or 'replace' here. 'strict' matches
                        # LogisticRegressionBaseline above.
                        decode_error='strict',
                        preprocessor=None,
                        tokenizer=None,
                        token_pattern=r'(?u)\b\w\w+\b',
                        max_features=None,
                        vocabulary=None)
                ),
                (
                    'tfidf',
                    TfidfTransformer(smooth_idf=True)
                ),
                (
                    'scaler',
                    MaxAbsScaler()
                )
            ]))
    ),
    (
        'classifier',
        LogisticRegression(
            penalty='l2',
            dual=False,
            tol=1e-4,
            # no intercept: the ranker scores each action choice and the
            # score difference decides the ranking, so a shared intercept
            # would cancel out
            fit_intercept=False,
            intercept_scaling=1.,
            solver='lbfgs',
            max_iter=100,
            warm_start=True)
    )
])
"""Rank using logistic regression on bag-of-ngrams features."""
121 |
122 | LOGISTIC_RANKER_HYPER_PARAMETERS = {
123 | 'featurizer__transformer__vectorizer__strip_accents': ['ascii', 'unicode', None],
124 | 'featurizer__transformer__vectorizer__lowercase': [True, False],
125 | 'featurizer__transformer__vectorizer__stop_words': ['english', None],
126 | 'featurizer__transformer__vectorizer__ngram_range': [
127 | (lo, hi)
128 | for lo in range(1, 2)
129 | for hi in range(lo, lo + 5)
130 | ],
131 | 'featurizer__transformer__vectorizer__analyzer': ['word', 'char', 'char_wb'],
132 | 'featurizer__transformer__vectorizer__max_df': (0.90, 1., 'uniform'),
133 | 'featurizer__transformer__vectorizer__min_df': (0., 0.10, 'uniform'),
134 | 'featurizer__transformer__vectorizer__binary': [True, False],
135 | 'featurizer__transformer__tfidf__norm': ['l1', 'l2', None],
136 | 'featurizer__transformer__tfidf__use_idf': [True, False],
137 | 'featurizer__transformer__tfidf__sublinear_tf': [True, False],
138 | 'classifier__C': (1e-6, 1e2, 'log-uniform'),
139 | 'classifier__class_weight': ['balanced', None],
140 | }
141 | """The hyper-param search space for ``LogisticRankerBaseline``."""
142 |
--------------------------------------------------------------------------------
/src/scruples/baselines/loss.py:
--------------------------------------------------------------------------------
1 | """Custom losses for baselines."""
2 |
3 | import torch as th
4 | from torch.nn import functional as F
5 |
6 |
class SoftCrossEntropyLoss(th.nn.Module):
    """Cross-entropy against soft reference labels.

    The targets may be given in either of two ways: as averaged label
    distributions, or as raw per-label counts. Averaged labels use the
    soft label information but discard how certain one instance's
    labeling is relative to another's; raw counts use all of the
    annotation information. The two are equivalent when every instance
    has the same number of labels.
    """

    def forward(self, input, target):
        # per-instance (soft) log-likelihood of the targets under the
        # predicted distribution, averaged over the batch
        log_probs = F.log_softmax(input, dim=-1)
        return -(target * log_probs).sum(dim=-1).mean()
20 |
21 |
class DirichletMultinomialLoss(th.nn.Module):
    """Negative log-likelihood for a dirichlet-multinomial.

    ``inputs`` are treated as log-concentration parameters and
    ``targets`` as the observed label counts. Note that the likelihood
    computed is that of the observed labels themselves, not of the
    sufficient statistic derived from them (the counts of each label);
    only the sufficient statistic is needed to compute it, and both
    lead to the same MLE.
    """

    def forward(self, inputs, targets):
        # exponentiate to obtain the (positive) concentration parameters
        alphas = th.exp(inputs)
        posteriors = alphas + targets
        log_likelihoods = (
            th.lgamma(alphas.sum(dim=-1))
            + th.lgamma(posteriors).sum(dim=-1)
            - th.lgamma(posteriors.sum(dim=-1))
            - th.lgamma(alphas).sum(dim=-1)
        )
        return -log_likelihoods.mean()
37 |
--------------------------------------------------------------------------------
/src/scruples/baselines/metrics.py:
--------------------------------------------------------------------------------
1 | """Metrics for assessing baseline models."""
2 |
3 | from sklearn import metrics
4 |
5 |
6 | METRICS = {
7 | 'accuracy': (
8 | 'accuracy',
9 | metrics.accuracy_score,
10 | {
11 | 'greater_is_better': True,
12 | 'needs_proba': False
13 | }
14 | ),
15 | 'balanced_accuracy': (
16 | 'balanced accuracy',
17 | metrics.balanced_accuracy_score,
18 | {
19 | 'greater_is_better': True,
20 | 'needs_proba': False
21 | }
22 | ),
23 | 'precision_micro': (
24 | 'precision (micro)',
25 | lambda y_true, y_pred: metrics.precision_score(
26 | y_true=y_true,
27 | y_pred=y_pred,
28 | average='micro'),
29 | {
30 | 'greater_is_better': True,
31 | 'needs_proba': False
32 | }
33 | ),
34 | 'recall_micro': (
35 | 'recall (micro)',
36 | lambda y_true, y_pred: metrics.recall_score(
37 | y_true=y_true,
38 | y_pred=y_pred,
39 | average='micro'),
40 | {
41 | 'greater_is_better': True,
42 | 'needs_proba': False
43 | }
44 | ),
45 | 'f1_micro': (
46 | 'f1 (micro)',
47 | lambda y_true, y_pred: metrics.f1_score(
48 | y_true=y_true,
49 | y_pred=y_pred,
50 | average='micro'),
51 | {
52 | 'greater_is_better': True,
53 | 'needs_proba': False
54 | }
55 | ),
56 | 'precision_macro': (
57 | 'precision (macro)',
58 | lambda y_true, y_pred: metrics.precision_score(
59 | y_true=y_true,
60 | y_pred=y_pred,
61 | average='macro'),
62 | {
63 | 'greater_is_better': True,
64 | 'needs_proba': False
65 | }
66 | ),
67 | 'recall_macro': (
68 | 'recall (macro)',
69 | lambda y_true, y_pred: metrics.recall_score(
70 | y_true=y_true,
71 | y_pred=y_pred,
72 | average='macro'),
73 | {
74 | 'greater_is_better': True,
75 | 'needs_proba': False
76 | }
77 | ),
78 | 'f1_macro': (
79 | 'f1 (macro)',
80 | lambda y_true, y_pred: metrics.f1_score(
81 | y_true=y_true,
82 | y_pred=y_pred,
83 | average='macro'),
84 | {
85 | 'greater_is_better': True,
86 | 'needs_proba': False
87 | }
88 | ),
89 | 'f1_weighted': (
90 | 'f1 (weighted)',
91 | lambda y_true, y_pred: metrics.f1_score(
92 | y_true=y_true,
93 | y_pred=y_pred,
94 | average='weighted'),
95 | {
96 | 'greater_is_better': True,
97 | 'needs_proba': False
98 | }
99 | ),
100 | 'log_loss': (
101 | 'log loss',
102 | lambda y_true, y_pred: metrics.log_loss(
103 | y_true=y_true,
104 | y_pred=y_pred,
105 | eps=1e-9),
106 | {
107 | 'greater_is_better': False,
108 | 'needs_proba': True
109 | }
110 | ),
111 | 'matthews_corrcoef': (
112 | 'matthews correlation coefficient',
113 | metrics.matthews_corrcoef,
114 | {
115 | 'greater_is_better': True,
116 | 'needs_proba': False
117 | }
118 | )
119 | # N.B., do not include a key for "xentropy" or "calibrated_xentropy"
120 | # in this dictionary. Those keys are reserved for the cross-entropy
121 | # between the predicted probabilities and the dataset's label
122 | # scores, which is computed in the scruples.scripts.analyze.*.predictions
123 | # scripts.
124 | }
125 | """A dictionary defining the important metrics to assess baselines.
126 |
127 | The dictionary maps metric names to ``(name, metric, scorer_kwargs)``
128 | tuples. ``name`` is the name of the metric, while ``metric`` is a
129 | function for computing it, and ``scorer_kwargs`` is a dictionary
130 | containing two keys: ``"greater_is_better"``, a boolean defining whether
131 | or not higher values are better, and ``"needs_proba"``, a boolean
132 | defining whether to pass the predicted labels or the predicted
133 | probabilities.
134 | """
135 |
--------------------------------------------------------------------------------
/src/scruples/baselines/naivebayes.py:
--------------------------------------------------------------------------------
1 | """Naive bayes baselines."""
2 |
3 | from sklearn.feature_extraction.text import CountVectorizer
4 | from sklearn.naive_bayes import (
5 | BernoulliNB,
6 | MultinomialNB,
7 | ComplementNB)
8 | from sklearn.pipeline import Pipeline
9 | from sklearn.preprocessing import FunctionTransformer
10 |
11 | from . import utils
12 |
13 |
14 | # the bernoulli naive bayes baseline
15 |
16 | BernoulliNBBaseline = Pipeline([
17 | (
18 | 'concatenator',
19 | FunctionTransformer(func=utils.concat_title_and_text, validate=False)
20 | ),
21 | (
22 | 'vectorizer',
23 | CountVectorizer(
24 | input='content',
25 | encoding='utf-8',
26 | decode_error='strict',
27 | preprocessor=None,
28 | tokenizer=None,
29 | token_pattern=r'(?u)\b\w\w+\b',
30 | max_features=None,
31 | vocabulary=None,
32 | binary=True)
33 | ),
34 | (
35 | 'classifier',
36 | BernoulliNB(binarize=None, fit_prior=True)
37 | )
38 | ])
39 | """Predict using bernoulli naive bayes on bag-of-ngrams features."""
40 |
41 | BERNOULLINB_HYPER_PARAMETERS = {
42 | 'vectorizer__strip_accents': ['ascii', 'unicode', None],
43 | 'vectorizer__lowercase': [True, False],
44 | 'vectorizer__stop_words': ['english', None],
45 | 'vectorizer__ngram_range': [
46 | (lo, hi)
47 | for lo in range(1, 2)
48 | for hi in range(lo, lo + 5)
49 | ],
50 | 'vectorizer__analyzer': ['word', 'char', 'char_wb'],
51 | 'vectorizer__max_df': (0.75, 1., 'uniform'),
52 | 'vectorizer__min_df': (0., 0.25, 'uniform'),
53 | 'classifier__alpha': (0., 5., 'uniform')
54 | }
55 | """The hyper-param search space for ``BernoulliNBBaseline``."""
56 |
57 |
58 | # the multinomial naive bayes baseline
59 |
60 | MultinomialNBBaseline = Pipeline([
61 | (
62 | 'concatenator',
63 | FunctionTransformer(func=utils.concat_title_and_text, validate=False)
64 | ),
65 | (
66 | 'vectorizer',
67 | CountVectorizer(
68 | input='content',
69 | encoding='utf-8',
70 | decode_error='strict',
71 | preprocessor=None,
72 | tokenizer=None,
73 | token_pattern=r'(?u)\b\w\w+\b',
74 | max_features=None,
75 | vocabulary=None,
76 | binary=False)
77 | ),
78 | (
79 | 'classifier',
80 | MultinomialNB(fit_prior=True)
81 | )
82 | ])
83 | """Predict using multinomial naive bayes on bag-of-ngrams features."""
84 |
85 | MULTINOMIALNB_HYPER_PARAMETERS = {
86 | 'vectorizer__strip_accents': ['ascii', 'unicode', None],
87 | 'vectorizer__lowercase': [True, False],
88 | 'vectorizer__stop_words': ['english', None],
89 | 'vectorizer__ngram_range': [
90 | (lo, hi)
91 | for lo in range(1, 2)
92 | for hi in range(lo, lo + 5)
93 | ],
94 | 'vectorizer__analyzer': ['word', 'char', 'char_wb'],
95 | 'vectorizer__max_df': (0.75, 1., 'uniform'),
96 | 'vectorizer__min_df': (0., 0.25, 'uniform'),
97 | 'classifier__alpha': (0., 5., 'uniform')
98 | }
99 | """The hyper-param search space for ``MultinomialNBBaseline``."""
100 |
101 |
102 | # the complement naive bayes baseline
103 |
104 | ComplementNBBaseline = Pipeline([
105 | (
106 | 'concatenator',
107 | FunctionTransformer(func=utils.concat_title_and_text, validate=False)
108 | ),
109 | (
110 | 'vectorizer',
111 | CountVectorizer(
112 | input='content',
113 | encoding='utf-8',
114 | decode_error='strict',
115 | preprocessor=None,
116 | tokenizer=None,
117 | token_pattern=r'(?u)\b\w\w+\b',
118 | max_features=None,
119 | vocabulary=None,
120 | binary=False)
121 | ),
122 | (
123 | 'classifier',
124 | ComplementNB(fit_prior=True)
125 | )
126 | ])
127 | """Predict using complement naive bayes on bag-of-ngrams features."""
128 |
129 | COMPLEMENTNB_HYPER_PARAMETERS = {
130 | 'vectorizer__strip_accents': ['ascii', 'unicode', None],
131 | 'vectorizer__lowercase': [True, False],
132 | 'vectorizer__stop_words': ['english', None],
133 | 'vectorizer__ngram_range': [
134 | (lo, hi)
135 | for lo in range(1, 2)
136 | for hi in range(lo, lo + 5)
137 | ],
138 | 'vectorizer__analyzer': ['word', 'char', 'char_wb'],
139 | 'vectorizer__max_df': (0.75, 1., 'uniform'),
140 | 'vectorizer__min_df': (0., 0.25, 'uniform'),
141 | 'classifier__alpha': (0., 5., 'uniform'),
142 | 'classifier__norm': [True, False]
143 | }
144 | """The hyper-param search space for ``ComplementNBBaseline``."""
145 |
--------------------------------------------------------------------------------
/src/scruples/baselines/resource.py:
--------------------------------------------------------------------------------
1 | """Baseline configuration for the scruples resource."""
2 |
3 | from . import (
4 | bert,
5 | labels,
6 | linear,
7 | roberta,
8 | style)
9 |
10 |
11 | SHALLOW_BASELINES = {
12 | 'prior': (
13 | labels.PriorBaseline,
14 | labels.PRIOR_HYPER_PARAMETERS
15 | ),
16 | 'stratified': (
17 | labels.StratifiedBaseline,
18 | labels.STRATIFIED_HYPER_PARAMETERS
19 | ),
20 | 'fewestwords': (
21 | style.FewestWordsBaseline,
22 | style.FEWEST_WORDS_HYPER_PARAMETERS
23 | ),
24 | 'mostwords': (
25 | style.MostWordsBaseline,
26 | style.MOST_WORDS_HYPER_PARAMETERS
27 | ),
28 | 'fewestcharacters': (
29 | style.FewestCharactersBaseline,
30 | style.FEWEST_CHARACTERS_HYPER_PARAMETERS
31 | ),
32 | 'mostcharacters': (
33 | style.MostCharactersBaseline,
34 | style.MOST_CHARACTERS_HYPER_PARAMETERS
35 | ),
36 | 'logisticranker': (
37 | linear.LogisticRankerBaseline,
38 | linear.LOGISTIC_RANKER_HYPER_PARAMETERS
39 | ),
40 | 'stylistic': (
41 | style.StyleRankerBaseline,
42 | style.STYLE_RANKER_HYPER_PARAMETERS
43 | )
44 | }
45 | """Shallow baseline models for the scruples resource."""
46 |
47 |
# Maps each baseline's name to a 4-tuple:
# (model factory, config, hyper-param search space, transform factory).
FINE_TUNE_LM_BASELINES = {
    'bert': (
        bert.BERTRanker,
        bert.BERT_RANKER_CONFIG,
        bert.BERT_RANKER_HYPER_PARAM_SPACE,
        bert.BERT_RANKER_TRANSFORM
    ),
    'roberta': (
        roberta.RoBERTaRanker,
        roberta.ROBERTA_RANKER_CONFIG,
        roberta.ROBERTA_RANKER_HYPER_PARAM_SPACE,
        roberta.ROBERTA_RANKER_TRANSFORM
    )
}
"""Fine-tuned language model baselines for the scruples resource."""
63 |
--------------------------------------------------------------------------------
/src/scruples/baselines/roberta.py:
--------------------------------------------------------------------------------
1 | """RoBERTa baselines."""
2 |
3 | from transformers import (
4 | RobertaForMultipleChoice,
5 | RobertaForSequenceClassification,
6 | RobertaTokenizer)
7 | import skopt
8 | import torch
9 |
10 | from ..data.labels import Label
11 | from ..dataset.transforms import (
12 | BertTransform,
13 | Compose,
14 | Map)
15 |
16 |
17 | # the RoBERTa sequence classification baseline
18 |
# N.B. an alias of the ``from_pretrained`` factory function, not a
# class; it is called with ROBERTA_CLASSIFIER_CONFIG['model'] kwargs.
RoBERTaClassifier = RobertaForSequenceClassification.from_pretrained
"""Predict fixed classes with a fine-tuned RoBERTa model."""
21 |
22 |
ROBERTA_CLASSIFIER_CONFIG = {
    'model': {
        # N.B. pretrained_model_name_or_path for the model must be the
        # same as pretrained_roberta for the transform
        'pretrained_model_name_or_path': 'roberta-large',
        # one output class per scruples label
        'num_labels': len(Label)
    },
    'transform': {
        # N.B. pretrained_roberta for the transform must be the same as
        # pretrained_model_name_or_path for the model
        'pretrained_roberta': 'roberta-large',
        'max_sequence_length': 512,
        # truncation strategies are forwarded to BertTransform
        'truncation_strategy_title': 'beginning',
        'truncation_strategy_text': 'beginning'
    }
}
"""Configuration for ``RoBERTaClassifier``."""
40 |
41 |
# skopt search-space dimensions; each name matches a training kwarg.
ROBERTA_CLASSIFIER_HYPER_PARAM_SPACE = [
    # learning rate
    skopt.space.Real(
        low=1e-8,
        high=1e-2,
        prior='log-uniform',
        name='lr'),
    # weight decay strength
    skopt.space.Real(
        low=1e-5,
        high=1e0,
        prior='log-uniform',
        name='weight_decay'),
    # presumably the fraction of training used for learning rate
    # warm-up -- confirm against the training loop
    skopt.space.Real(
        low=0.0,
        high=1.0,
        prior='uniform',
        name='warmup_proportion'),
    skopt.space.Integer(
        low=1,
        high=10,
        name='n_epochs'),
    # log of the train batch size (base not shown here; see trainer)
    skopt.space.Integer(
        low=3,
        high=10,
        name='log_train_batch_size')
]
"""The hyper-param search space for ``RoBERTaClassifier``."""
68 |
69 |
# N.B. the upper-case name is kept for backwards compatibility: this
# factory is used as a module-level constant elsewhere in the package.
def ROBERTA_CLASSIFIER_TRANSFORM(
        pretrained_roberta: str,
        max_sequence_length: int,
        truncation_strategy_title: str,
        truncation_strategy_text: str
):
    """Return the data transform for ``RoBERTaClassifier``.

    Previously a lambda bound to a name; a ``def`` gives the factory a
    proper ``__name__`` and docstring (PEP 8, E731) without changing
    the call signature.

    Parameters
    ----------
    pretrained_roberta : str
        The pretrained tokenizer to load. Must match the model's
        ``pretrained_model_name_or_path`` (see
        ``ROBERTA_CLASSIFIER_CONFIG``).
    max_sequence_length : int
        The maximum length of the tokenized sequence.
    truncation_strategy_title : str
        The truncation strategy for the title.
    truncation_strategy_text : str
        The truncation strategy for the text.

    Returns
    -------
    Compose
        A transform producing the ``input_ids`` and ``attention_mask``
        tensors the model consumes.
    """
    return Compose([
        BertTransform(
            tokenizer=RobertaTokenizer.from_pretrained(
                pretrained_roberta,
                do_lower_case=False),
            max_sequence_length=max_sequence_length,
            truncation_strategy=(
                truncation_strategy_title,
                truncation_strategy_text),
            starting_sep_token=True
        ),
        # convert the tokenizer's output lists into tensors
        lambda d: {
            'input_ids': torch.tensor(d['input_ids']),
            'attention_mask': torch.tensor(d['input_mask'])
        }
    ])
94 |
95 |
96 | # the RoBERTa ranking baseline
97 |
# N.B. an alias of the ``from_pretrained`` factory function, not a
# class; it is called with ROBERTA_RANKER_CONFIG['model'] kwargs.
RoBERTaRanker = RobertaForMultipleChoice.from_pretrained
"""Rank choices with a softmax over a fine-tuned RoBERTa model."""
100 |
101 |
ROBERTA_RANKER_CONFIG = {
    'model': {
        # N.B. pretrained_model_name_or_path for the model must be the
        # same as pretrained_roberta for the transform
        'pretrained_model_name_or_path': 'roberta-large'
    },
    'transform': {
        # N.B. pretrained_roberta for the transform must be the same as
        # pretrained_model_name_or_path for the model
        'pretrained_roberta': 'roberta-large',
        # applied per action choice (each choice is encoded separately
        # by ROBERTA_RANKER_TRANSFORM)
        'max_sequence_length': 90
    }
}
"""Configuration for ``RoBERTaRanker``."""
116 |
117 |
# skopt search-space dimensions; each name matches a training kwarg.
ROBERTA_RANKER_HYPER_PARAM_SPACE = [
    # learning rate
    skopt.space.Real(
        low=1e-8,
        high=1e-2,
        prior='log-uniform',
        name='lr'),
    # weight decay strength
    skopt.space.Real(
        low=1e-5,
        high=1e0,
        prior='log-uniform',
        name='weight_decay'),
    # presumably the fraction of training used for learning rate
    # warm-up -- confirm against the training loop
    skopt.space.Real(
        low=0.0,
        high=1.0,
        prior='uniform',
        name='warmup_proportion'),
    skopt.space.Integer(
        low=1,
        high=25,
        name='n_epochs'),
    # log of the train batch size (base not shown here; see trainer)
    skopt.space.Integer(
        low=3,
        high=10,
        name='log_train_batch_size')
]
"""The hyper-param search space for ``RoBERTaRanker``."""
144 |
145 |
# N.B. the upper-case name is kept for backwards compatibility: this
# factory is used as a module-level constant elsewhere in the package.
def ROBERTA_RANKER_TRANSFORM(
        pretrained_roberta: str,
        max_sequence_length: int
):
    """Return the data transform for ``RoBERTaRanker``.

    Previously a lambda bound to a name; a ``def`` gives the factory a
    proper ``__name__`` and docstring (PEP 8, E731) without changing
    the call signature.

    Parameters
    ----------
    pretrained_roberta : str
        The pretrained tokenizer to load. Must match the model's
        ``pretrained_model_name_or_path`` (see
        ``ROBERTA_RANKER_CONFIG``).
    max_sequence_length : int
        The maximum length of the tokenized sequence for each action
        choice.

    Returns
    -------
    Compose
        A transform producing stacked ``input_ids`` and
        ``attention_mask`` tensors with one row per action choice.
    """
    return Compose([
        # wrap each action in a tuple for passing it to BertTransform
        lambda actions: tuple((action, None) for action in actions),
        # map BertTransform across all the action choices
        Map(
            transform=BertTransform(
                tokenizer=RobertaTokenizer.from_pretrained(
                    pretrained_roberta,
                    do_lower_case=False),
                max_sequence_length=max_sequence_length,
                truncation_strategy=('beginning', 'beginning'),
                starting_sep_token=True
            )
        ),
        # collect the action choices and stack their tensors so the
        # choices can be their own dimension of the batch
        lambda ds: {
            'input_ids': torch.stack([
                torch.tensor(d['input_ids'])
                for d in ds
            ], dim=0),
            'attention_mask': torch.stack([
                torch.tensor(d['input_mask'])
                for d in ds
            ], dim=0)
        }
    ])
179 |
--------------------------------------------------------------------------------
/src/scruples/baselines/trees.py:
--------------------------------------------------------------------------------
1 | """Tree-based baselines."""
2 |
3 | from sklearn.feature_extraction.text import (
4 | CountVectorizer,
5 | TfidfTransformer)
6 | from sklearn.ensemble import RandomForestClassifier
7 | from sklearn.pipeline import Pipeline
8 | from sklearn.preprocessing import FunctionTransformer
9 |
10 | from . import utils
11 |
12 |
13 | # the random forest baseline
14 |
# Pipeline: concatenate title/text -> bag-of-ngrams counts -> tf-idf
# weighting -> random forest classifier.
RandomForestBaseline = Pipeline([
    (
        'concatenator',
        FunctionTransformer(func=utils.concat_title_and_text, validate=False)
    ),
    (
        'vectorizer',
        CountVectorizer(
            input='content',
            encoding='utf-8',
            decode_error='strict',
            preprocessor=None,
            tokenizer=None,
            token_pattern=r'(?u)\b\w\w+\b',
            max_features=None,
            vocabulary=None)
    ),
    (
        'tfidf',
        TfidfTransformer(smooth_idf=True)
    ),
    (
        'classifier',
        RandomForestClassifier(
            n_estimators=100,
            max_depth=None,
            # 'sqrt' is what the deprecated 'auto' setting resolved to
            # for classifiers; 'auto' was removed in scikit-learn 1.3,
            # so use the explicit equivalent.
            max_features='sqrt',
            max_leaf_nodes=None,
            min_impurity_decrease=0.,
            oob_score=False,
            n_jobs=1,
            verbose=0,
            warm_start=False)
    )
])
"""Predict using a random forest on bag-of-ngrams features."""
51 |
# N.B. ``dict(...)`` with keyword arguments preserves the same key
# order as a literal would.
RANDOM_FOREST_HYPER_PARAMETERS = dict(
    vectorizer__strip_accents=['ascii', 'unicode', None],
    vectorizer__lowercase=[True, False],
    vectorizer__stop_words=['english', None],
    # n-gram ranges (1, 1) through (1, 5)
    vectorizer__ngram_range=[(1, hi) for hi in range(1, 6)],
    vectorizer__analyzer=['word', 'char', 'char_wb'],
    vectorizer__max_df=(0.75, 1., 'uniform'),
    vectorizer__min_df=(0., 0.25, 'uniform'),
    vectorizer__binary=[True, False],
    tfidf__norm=['l1', 'l2', None],
    tfidf__use_idf=[True, False],
    tfidf__sublinear_tf=[True, False],
    classifier__criterion=['gini', 'entropy'],
    classifier__min_samples_split=(2, 500),
    classifier__min_samples_leaf=(1, 250),
    classifier__min_weight_fraction_leaf=(0., .25, 'uniform'),
    classifier__bootstrap=[True, False],
    classifier__class_weight=['balanced', 'balanced_subsample', None]
)
"""The hyper-param search space for ``RandomForestBaseline``."""
76 |
--------------------------------------------------------------------------------
/src/scruples/baselines/utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for baselines on scruples."""
2 |
3 | from typing import Any, Dict
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from scipy.special import gammaln
8 | from sklearn.base import (
9 | BaseEstimator,
10 | TransformerMixin)
11 | from sklearn.utils.validation import check_is_fitted
12 |
13 |
14 | # functions
15 |
def concat_title_and_text(features: pd.DataFrame) -> np.ndarray:
    """Return the titles and texts joined by a newline.

    Parameters
    ----------
    features : pd.DataFrame
        The features for the scruples dataset. Must have ``title`` and
        ``text`` columns.

    Returns
    -------
    np.ndarray
        A numpy array in which each element is the instance's title and
        text concatenated with a newline character in between.
    """
    combined = features['title'] + '\n' + features['text']

    return combined.values
31 |
32 |
def dirichlet_multinomial(log_alphas: np.ndarray) -> np.ndarray:
    """Return class probabilities from a dirichlet-multinomial model.

    Parameters
    ----------
    log_alphas : np.ndarray
        An n x k dimensional numpy array where n is the number of
        samples and k is the number of classes. The values of the array
        should correspond to the log of the alpha parameters for the
        predicted dirichlet distribution corresponding to each instance.

    Returns
    -------
    np.ndarray
        An n x k dimensional array giving the class probabilities
        corresponding to ``log_alphas`` for each sample.
    """
    alphas = np.exp(log_alphas)
    # normalize each row of alphas so it sums to one
    totals = np.sum(alphas, axis=-1, keepdims=True)

    return alphas / totals
52 |
53 |
54 | # classes
55 |
class ResourceTransformer(BaseEstimator, TransformerMixin):
    """Featurize the action pairs from the scruples resource.

    ``ResourceTransformer`` applies the same featurization pipeline
    (``self.transformer``) to both actions in an instance from the
    scruples resource and then takes the difference of their
    features.

    You can set parameters on the ``self.transformer`` attribute by
    prefixing parameters to ``ResourceTransformer`` with
    ``transformer__``.

    ``ResourceTransformer`` is particularly useful in front of linear
    models like logistic regression, since applying the model to the
    difference of the features is the same as taking the difference of
    the final scores.

    Attributes
    ----------
    See `Parameters`_.

    Parameters
    ----------
    transformer : Transformer
        The transformer to apply to the actions.
    """
    # the prefix marking parameters that should be routed to
    # ``self.transformer`` in ``set_params``
    _TRANSFORMER_PREFIX = 'transformer__'

    def __init__(
            self,
            transformer: TransformerMixin
    ) -> None:
        self.transformer = transformer

    def set_params(
            self,
            **params: Dict[str, Any]
    ) -> 'ResourceTransformer':
        """Set parameters on this instance and on ``self.transformer``.

        Parameters prefixed with ``transformer__`` are forwarded (with
        the prefix stripped) to ``self.transformer``; all other
        parameters are set on this instance.

        Parameters
        ----------
        **params : Dict[str, Any]
            The parameters to set.

        Returns
        -------
        self : object
            The instance.
        """
        self_params = {}
        transformer_params = {}
        prefix = self._TRANSFORMER_PREFIX
        for param, value in params.items():
            if param.startswith(prefix):
                # strip the prefix by its length rather than a
                # hard-coded index, so the routing cannot drift from
                # the prefix definition
                transformer_params[param[len(prefix):]] = value
            else:
                self_params[param] = value
        # set the parameters on this instance
        super().set_params(**self_params)
        # set the parameters on the transformer attribute
        self.transformer.set_params(**transformer_params)

        return self

    def fit(
            self,
            X: pd.DataFrame,
            y: np.ndarray = None
    ) -> 'ResourceTransformer':
        """Fit the instance to ``X``.

        Fitting an instance of ``ResourceTransformer`` fits its
        ``self.transformer`` attribute to the data. The ``y`` argument
        is ignored.

        Parameters
        ----------
        X : pd.DataFrame
            The data to fit. Must have ``action0`` and ``action1``
            columns.
        y : None
            An ignored argument.

        Returns
        -------
        self : object
            The instance.
        """
        # fit the transformer on both actions' texts so its vocabulary
        # (or other learned state) covers either action
        X_ = pd.concat([X['action0'], X['action1']])

        self.transformer.fit(X_)

        self._fitted = True

        return self

    def transform(
            self,
            X: pd.DataFrame
    ) -> Any:
        """Transform ``X``.

        Parameters
        ----------
        X : pd.DataFrame
            The data to transform. Must have ``action0`` and
            ``action1`` columns.

        Returns
        -------
        Any
            The difference of the features for the actions derived by
            applying ``self.transformer`` to them.
        """
        check_is_fitted(self, '_fitted')

        # N.B. the difference is action1 minus action0
        return (
            self.transformer.transform(X['action1'])
            - self.transformer.transform(X['action0'])
        )
160 |
--------------------------------------------------------------------------------
/src/scruples/data/__init__.py:
--------------------------------------------------------------------------------
1 | """Definitions and utilities for working with the data."""
2 |
3 | from . import (
4 | comment,
5 | label_scores,
6 | labels,
7 | post,
8 | post_types,
9 | utils)
10 |
--------------------------------------------------------------------------------
/src/scruples/data/comment.py:
--------------------------------------------------------------------------------
1 | """A class representing a comment."""
2 |
3 | from typing import Optional
4 |
5 | import attr
6 |
7 | from .. import settings
8 | from . import utils
9 | from .labels import Label
10 |
11 |
@attr.s(frozen=True, kw_only=True)
class Comment:
    """A class representing a comment.

    Attributes
    ----------
    label : Optional[Label]
        The label expressed by the comment. If no label can be extracted
        or multiple labels are extracted from the comment, then this
        attribute is ``None``.
    is_top_level : bool
        ``True`` if the comment is a top-level comment (i.e., a direct
        response to a link and not another comment).
    has_empty_body : bool
        ``True`` if the body text of the comment is empty.
    is_deleted : bool
        ``True`` if the comment is deleted.
    is_by_automoderator : bool
        ``True`` if the comment is by the AutoModerator.
    is_spam : bool
        ``True`` if the comment is classified as spam (i.e., has
        irrelevant or meaningless content) by a set of rules-based
        filters.
    is_good : bool
        ``True`` if the comment is a good candidate for contributing a
        label. Some comments have good content but are not candidates
        for inclusion in the dataset due to certain dataset design
        decisions, such as requiring comments to be top-level.

    See `Parameters`_ for additional attributes.

    Parameters
    ----------
    id : str
        A unique ID for the comment.
    subreddit_id : str
        The ID of the subreddit the comment was posted in. The ID has
        ``"t5_"`` prepended to it to represent the fact that it is a
        *subreddit* ID.
    subreddit : str
        The name of the subreddit the comment was posted in.
    link_id : str
        The ID of the post that the comment was made on. The ID has
        ``"t3_"`` prepended to it to represent the fact that it is a
        *post* ID.
    parent_id : str
        The ID of the parent object (either a comment or a post). If the
        parent object is a post, the ID will begin with ``"t3_"``. If
        the parent object is a comment, the ID will begin with
        ``"t1_"``.
    created_utc : int
        The time that the comment was created in seconds since the
        epoch.
    author : str
        The username of the comment's author.
    body : str
        The body text of the comment.
    score : int
        The score (upvotes minus downvotes) of the comment.
    controversiality : int
        The controversiality score for the comment.
    gilded : int
        The number of times the comment has been gilded.
    """
    # identifying information
    id: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)

    # reddit location
    subreddit_id: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)
    subreddit: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)
    link_id: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)
    parent_id: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)

    # creation
    created_utc: int = attr.ib(
        validator=attr.validators.instance_of(int),
        converter=int)
    author: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)

    # content
    body: str = attr.ib(
        validator=attr.validators.instance_of(str),
        converter=str)

    # user interactions
    score: int = attr.ib(
        validator=attr.validators.instance_of(int),
        converter=int)
    controversiality: int = attr.ib(
        validator=attr.validators.instance_of(int),
        converter=int)
    gilded: int = attr.ib(
        validator=attr.validators.instance_of(int),
        converter=int)

    # computed content properties

    @utils.cached_property
    def label(self) -> Optional[Label]:
        # ``None`` when zero or multiple labels are found in the body
        # (see the class docstring and ``Label.extract_from_text``)
        return Label.extract_from_text(self.body)

    # computed properties for identifying comments to count in label
    # scores

    @utils.cached_property
    def is_top_level(self) -> bool:
        # a top-level comment's parent is the post (link) itself
        return self.parent_id == self.link_id

    @utils.cached_property
    def has_empty_body(self) -> bool:
        return self.body == ""

    @utils.cached_property
    def is_deleted(self) -> bool:
        # '[deleted]' / '[removed]' are the placeholder bodies left
        # behind when a comment is deleted or removed
        return (
            self.body == '[deleted]'
            or self.body == '[removed]'
        )

    @utils.cached_property
    def is_by_automoderator(self) -> bool:
        return self.author == settings.AUTO_MODERATOR_NAME

    @utils.cached_property
    def is_spam(self) -> bool:
        # N.B. place cheaper predicates earlier so short-circuiting can
        # avoid evaluating more expensive predicates.
        return (
            self.has_empty_body
            or self.is_deleted
            or self.is_by_automoderator
        )

    @utils.cached_property
    def is_good(self) -> bool:
        # N.B. place cheaper predicates earlier so short-circuiting can
        # avoid evaluating more expensive predicates.
        return (
            self.is_top_level
            and not self.is_spam
        )
--------------------------------------------------------------------------------
/src/scruples/data/label_scores.py:
--------------------------------------------------------------------------------
1 | """A class representing a collection of scores for the labels."""
2 |
3 | from typing import Dict, Optional
4 |
5 | import attr
6 |
7 | from .. import settings
8 | from . import utils
9 | from .labels import BinarizedLabel, Label
10 |
11 |
@attr.s(frozen=True, kw_only=True)
class LabelScores:
    """A class representing scores for all the labels.

    Attributes
    ----------
    binarized_label_to_score : Dict[BinarizedLabel, int]
        A dictionary mapping each binarized label to its corresponding
        score.
    best_binarized_label : BinarizedLabel
        The overall highest scoring binarized label. Ties are broken
        arbitrarily.
    best_label : Label
        The overall highest scoring label. Ties are broken arbitrarily.
    has_all_zero_binarized_label_scores : bool
        ``True`` if all the binarized label scores are zero.
    has_all_zero_label_scores : bool
        ``True`` if all the label scores are zero.
    is_good : bool
        ``True`` if the label scores are considered good for inclusion
        in the final dataset.

    See `Parameters`_ for additional attributes.

    Parameters
    ----------
    label_to_score : Dict[Label, int]
        A dictionary mapping each label to its corresponding score.
    """
    label_to_score: Dict[Label, int] = attr.ib(
        validator=attr.validators.deep_mapping(
            key_validator=attr.validators.instance_of(Label),
            value_validator=attr.validators.instance_of(int)))

    # computed content properties

    @utils.cached_property
    def binarized_label_to_score(self) -> Dict[BinarizedLabel, int]:
        # start every binarized label at zero so each member has an
        # entry, even when it receives no score
        binarized_label_to_score = {
            binarized_label: 0
            for binarized_label in BinarizedLabel
        }
        for label, score in self.label_to_score.items():
            binarized_label = BinarizedLabel.binarize(label)
            # labels for which binarize returns ``None`` are dropped
            # from the binarized scores
            if binarized_label is not None:
                binarized_label_to_score[binarized_label] += score

        return binarized_label_to_score

    @utils.cached_property
    def best_binarized_label(self) -> Optional[BinarizedLabel]:
        # N.B. annotated Optional, but binarized_label_to_score has an
        # entry for every BinarizedLabel member, so max always finds a
        # label here.
        return max(
            self.binarized_label_to_score.items(),
            key=lambda t: t[1]
        )[0]

    @utils.cached_property
    def best_label(self) -> Label:
        return max(
            self.label_to_score.items(),
            key=lambda t: t[1]
        )[0]

    # computed properties for identifying good label scores

    @utils.cached_property
    def has_all_zero_binarized_label_scores(self) -> bool:
        return all(
            v == 0
            for v in self.binarized_label_to_score.values()
        )

    @utils.cached_property
    def has_all_zero_label_scores(self) -> bool:
        return all(v == 0 for v in self.label_to_score.values())

    @utils.cached_property
    def is_good(self) -> bool:
        # N.B. place cheaper predicates earlier so short-circuiting can
        # avoid evaluating more expensive predicates.
        return (
            not self.has_all_zero_binarized_label_scores
            and not self.has_all_zero_label_scores
        )
96 |
--------------------------------------------------------------------------------
/src/scruples/data/post_types.py:
--------------------------------------------------------------------------------
1 | """Definitions of post types and code for extracting them."""
2 |
3 | import enum
4 | from typing import (
5 | List,
6 | Optional)
7 |
8 | import regex
9 |
10 |
@enum.unique
class PostType(enum.Enum):
    """A post type.

    Posts are categorized into three types:

    1. **HISTORICAL** : The author is asking if they're in the wrong, based
       on an event that has actually happened.
    2. **HYPOTHETICAL** : The author is asking if they would be in the wrong,
       if they were to perform some action.
    3. **META** : The post is discussing the subreddit itself.

    Attributes
    ----------
    See `Parameters`_.

    Parameters
    ----------
    index : int
        A (unique) numerical index assigned to the post type.
    reddit_name : str
        The name for the post type used on the subreddit from which the
        data originates.
    patterns : List[str]
        A list of strings, each representing a regular expression
        pattern used to extract that post type from a post's title. Note
        that patterns are compiled to regular expressions when they're
        bound to the ``PostType`` instance as an attribute.
    """
    HISTORICAL = (0, 'AITA', [
        r'\m(?i:AITAH?)\M',
        r'(?i:Am I the asshole){e<=2}'
    ])
    HYPOTHETICAL = (1, 'WIBTA', [
        r'\m(?i:WIBTAH?)\M',
        r'(?i:Would I be the asshole){e<=2}'
    ])
    META = (2, 'META', [
        r'\mMETA\M',
        r'\[(?i:META)\]'
    ])

    @classmethod
    def extract_from_title(
            cls,
            title: str
    ) -> Optional['PostType']:
        """Return a post type extracted from ``title`` or ``None``.

        A post type is returned only when exactly one post type matches
        ``title``; ambiguous or non-matching titles yield ``None``.

        Parameters
        ----------
        title : str
            The title string from which to extract the post type.

        Returns
        -------
        Optional[PostType]
            The extracted post type.
        """
        matching_post_types = [
            post_type
            for post_type in cls
            if post_type.in_(title)
        ]

        return (
            matching_post_types[0]
            if len(matching_post_types) == 1
            else None
        )

    def __init__(
            self,
            index: int,
            reddit_name: str,
            patterns: List[str]
    ) -> None:
        self.index = index
        self.reddit_name = reddit_name
        # compile the patterns up front so matching titles is cheap
        self.patterns = list(map(regex.compile, patterns))

    def in_(
            self,
            title: str
    ) -> bool:
        """Return ``True`` if ``title`` expresses the post type.

        Return ``True`` if the post type has any pattern that matches a
        substring from ``title``.

        Parameters
        ----------
        title : str
            The title string to check for the post type.

        Returns
        -------
        bool
            Whether or not the post type is in ``title``.
        """
        for pattern in self.patterns:
            if pattern.search(title):
                return True

        return False
110 |
--------------------------------------------------------------------------------
/src/scruples/data/utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for working with data."""
2 |
3 | from functools import wraps
4 | from typing import (
5 | Any,
6 | Callable,
7 | Dict)
8 |
9 | import attr
10 |
11 |
def instantiate_attrs_with_extra_kwargs(
        cls: Any,
        **kwargs: Dict[str, Any]
):
    """Return ``cls`` instantiated with ``kwargs`` ignoring extra kwargs.

    Parameters
    ----------
    cls : Object
        An object that has been decorated with ``@attr.s``.
    **kwargs : Dict[str, Any]
        Any keyword arguments to use when instantiating ``cls``. Extra
        keyword arguments will be ignored.

    Returns
    -------
    cls
        An instance of ``cls``.

    Raises
    ------
    ValueError
        If ``cls`` is not decorated with ``@attr.s``.
    """
    if not attr.has(cls):
        raise ValueError(f'{cls} must be decorated with @attr.s')

    attr_names = attr.fields_dict(cls).keys()
    # N.B. only pass through the keys actually present in kwargs:
    # previously a declared attribute absent from kwargs raised
    # KeyError even when the attribute had a default value. Filtering
    # lets attrs apply its defaults (and report missing required
    # attributes with a clear TypeError).
    return cls(**{
        k: kwargs[k]
        for k in attr_names
        if k in kwargs
    })
34 |
35 |
def cached_property(method: Callable):
    """Decorate a method to act as a cached property.

    Works like ``@property`` but memoizes the computed value on the
    instance the first time it is accessed. It is intended to only be
    used on the methods of classes decorated with ``@attr.s`` where
    ``frozen=True``.

    Parameters
    ----------
    method : Callable, required
        The method to decorate. ``method`` should take only one
        argument: ``self``.

    Returns
    -------
    Callable
        The decorated method.

    Notes
    -----
    Caching the value is safe on a frozen attrs class because the
    object is intended to be immutable. The cached value is stored on
    the instance itself, so it is garbage collected along with the
    instance.
    """
    @wraps(method)
    def getter(self):
        cache_attr = f'_{method.__name__}'
        try:
            return getattr(self, cache_attr)
        except AttributeError:
            result = method(self)

            # the instance may be frozen (attrs), so bypass any custom
            # __setattr__ by going through object.__setattr__
            object.__setattr__(self, cache_attr, result)

            return result

    return property(getter)
75 |
--------------------------------------------------------------------------------
/src/scruples/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | """Classes and functions for reading the scruples dataset."""
2 |
3 | from . import (
4 | readers,
5 | transforms)
6 |
--------------------------------------------------------------------------------
/src/scruples/demos/__init__.py:
--------------------------------------------------------------------------------
1 | """Demos for scruples."""
2 |
3 | from . import (
4 | norms,
5 | scoracle)
6 |
--------------------------------------------------------------------------------
/src/scruples/demos/norms/__init__.py:
--------------------------------------------------------------------------------
1 | """A demo for reproducing normative decisions."""
2 |
3 | from . import (
4 | app,
5 | utils)
6 |
--------------------------------------------------------------------------------
/src/scruples/demos/norms/static/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/scruples/c79697e97d24f5b43b5ff030a9b16011a32e4a5f/src/scruples/demos/norms/static/favicon.ico
--------------------------------------------------------------------------------
/src/scruples/demos/norms/utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for the ``norms`` demo."""
2 |
3 | from typing import (
4 | Any,
5 | Callable,
6 | List,
7 | Optional,
8 | Tuple)
9 |
10 | import torch
11 |
12 |
class PredictionDataset(torch.utils.data.Dataset):
    """A PyTorch ``Dataset`` class for prediction.

    Wraps a list of feature values, optionally applying a transform to
    each item as it is fetched.

    Parameters
    ----------
    features : List[Any], required
        The list of features for the dataset's instances.
    transform : Optional[Callable], optional (default=None)
        The transformation to apply to the features. If ``None``, no
        transformation is applied.
    """
    def __init__(
            self,
            features: List[Any],
            transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()

        self.features = features
        self.transform = transform

    def __len__(self) -> int:
        # one instance per feature value
        return len(self.features)

    def __getitem__(self, key: int) -> Any:
        example = self.features[key]

        # apply the transform lazily, at retrieval time
        return self.transform(example) if self.transform else example
44 |
--------------------------------------------------------------------------------
/src/scruples/demos/scoracle/__init__.py:
--------------------------------------------------------------------------------
1 | """A demo for the oracle score estimator."""
2 |
3 | from . import app
4 |
--------------------------------------------------------------------------------
/src/scruples/demos/scoracle/static/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/scruples/c79697e97d24f5b43b5ff030a9b16011a32e4a5f/src/scruples/demos/scoracle/static/favicon.ico
--------------------------------------------------------------------------------
/src/scruples/extraction/__init__.py:
--------------------------------------------------------------------------------
1 | """Code and utilities for extracting data from and transforming text."""
2 |
3 | from . import (
4 | base,
5 | filters,
6 | normalizers,
7 | transformers)
8 |
--------------------------------------------------------------------------------
/src/scruples/extraction/filters.py:
--------------------------------------------------------------------------------
1 | """Filter functions."""
2 |
3 | import regex
4 |
5 | from . import base
6 | from .. import utils
7 |
8 |
class EmptyStringFilter(base.LoggedCallable):
    """Filter out the empty string."""

    def apply(self, s: str) -> bool:
        """Return ``True`` if ``s`` is the empty string.

        Parameters
        ----------
        s : str, required
            The string to test.

        Returns
        -------
        bool
            ``True`` if ``s`` is the empty string, otherwise ``False``.
        """
        is_empty = s == ''

        return is_empty
29 |
30 |
class TooFewCharactersFilter(base.LoggedCallable):
    """Filter strings with too few characters.

    Parameters
    ----------
    min_chars : int, required
        The minimum number of characters to allow.
    """

    def __init__(self, min_chars: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.min_chars = min_chars

    def apply(self, s: str) -> bool:
        """Return ``True`` if ``s`` has too few characters.

        Parameters
        ----------
        s : str, required
            The string to test.

        Returns
        -------
        bool
            ``True`` if ``s`` has fewer than ``self.min_chars``
            characters, otherwise ``False``.
        """
        if not isinstance(s, str):
            raise ValueError(f's ({s}) must be a string.')

        return self.min_chars > len(s)
71 |
72 |
class TooFewWordsFilter(base.LoggedCallable):
    """Filter out strings with too few words.

    Parameters
    ----------
    min_words : int, required
        The minimum number of words to allow.
    """

    def __init__(self, min_words: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.min_words = min_words

    def apply(self, s: str) -> bool:
        """Return ``True`` if ``s`` has too few words.

        Parameters
        ----------
        s : str, required
            The string to test.

        Returns
        -------
        bool
            ``True`` if ``s`` has fewer than ``self.min_words`` words,
            otherwise ``False``.
        """
        if not isinstance(s, str):
            raise ValueError(f's ({s}) must be a string.')

        return self.min_words > utils.count_words(s)
113 |
114 |
class PrefixFilter(base.LoggedCallable):
    """Filter out strings with a certain prefix.

    Parameters
    ----------
    prefix : str, required
        The prefix to look for when filtering strings.
    case_sensitive : bool, optional (default=False)
        Whether or not to match the prefix in a case-sensitive fashion.
    """

    def __init__(
            self,
            prefix: str,
            case_sensitive: bool = False,
            *args,
            **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)

        self.prefix, self.case_sensitive = prefix, case_sensitive

    def apply(
            self,
            s: str
    ) -> bool:
        """Check whether ``s`` starts with the configured prefix.

        Parameters
        ----------
        s : str, required
            The string to test.

        Returns
        -------
        bool
            ``True`` when ``s`` begins with ``self.prefix`` (ignoring
            case unless ``self.case_sensitive`` is set), and ``False``
            otherwise.

        Raises
        ------
        ValueError
            If ``s`` is not a string.
        """
        if not isinstance(s, str):
            raise ValueError(f's ({s}) must be a string.')

        if self.case_sensitive:
            return s.startswith(self.prefix)

        return s.lower().startswith(self.prefix.lower())
164 |
165 |
class StartsWithGerundFilter(base.LoggedCallable):
    """Filter strings which have no gerunds in the first few words."""

    # match strings whose first, second, or third word ends in "ing",
    # allowing surrounding quote / bracket punctuation (\p{Ps} / \p{Pe}
    # are open / close brackets, \p{Pi} / \p{Pf} are initial / final
    # quotes, and \m / \M are the regex package's word boundaries)
    _pattern = regex.compile(
        r'^(?:[\p{Ps}\p{Pi}"\']?\m[^\s]+\M[\p{Pe}\p{Pf}"\']? ){0,2}'
        r'[\p{Pi}"\']?\m[^\s]+ing\M[\p{Pf}"\']?')

    def apply(
            self,
            s: str
    ) -> bool:
        """Return ``True`` if ``s`` has no gerunds in the first 3 words.

        Parameters
        ----------
        s : str, required
            The string to test.

        Returns
        -------
        bool
            ``True`` if ``s`` has no gerunds in the first three words,
            otherwise ``False``.

        Raises
        ------
        ValueError
            If ``s`` is not a string.
        """
        # validate the input for consistency with the other filters in
        # this module
        if not isinstance(s, str):
            raise ValueError(f's ({s}) must be a string.')

        return self._pattern.match(s) is None
191 |
192 |
class WhWordFilter(base.LoggedCallable):
    """Filter strings which start with a wh-word."""

    # the wh-words to detect at the start of a string
    _wh_words = [
        'why',
        'who',
        'which',
        'what',
        'where',
        'when',
        'how'
    ]
    # compile each wh-word into a case-insensitive pattern matching it
    # at the start of the string, optionally wrapped in quotes (\p{Pi}
    # and \p{Pf} are initial / final quote punctuation, \m and \M are
    # the regex package's word boundaries)
    _wh_word_patterns = [
        regex.compile(r'(?i:[\p{{Pi}}"\']?\m{}\M[\p{{Pf}}"\']?)'.format(w))
        for w in _wh_words
    ]

    def apply(
            self,
            s: str
    ) -> bool:
        """Return ``True`` if ``s`` starts with a wh-word.

        Parameters
        ----------
        s : str, required
            The string to test.

        Returns
        -------
        bool
            ``True`` if ``s`` begins with a wh-word, otherwise
            ``False``.

        Raises
        ------
        ValueError
            If ``s`` is not a string.
        """
        # validate the input for consistency with the other filters in
        # this module
        if not isinstance(s, str):
            raise ValueError(f's ({s}) must be a string.')

        return any(pattern.match(s) for pattern in self._wh_word_patterns)
229 |
--------------------------------------------------------------------------------
/src/scruples/extraction/transformers.py:
--------------------------------------------------------------------------------
1 | """Transformations for extracting data."""
2 |
3 | from pattern.en import (
4 | conjugate,
5 | PRESENT,
6 | INDICATIVE,
7 | PROGRESSIVE)
8 | import spacy
9 |
10 | from . import base
11 |
12 |
# TODO: remove the following lines once this issue:
# https://github.com/clips/pattern/issues/243, is resolved.
# The first call to conjugate can raise due to the bug above; calling
# it once and discarding the error works around the problem. Catch
# Exception (not a bare except) so SystemExit / KeyboardInterrupt
# still propagate.
try:
    conjugate('give')
except Exception:
    pass
17 |
18 |
class GerundifyingTransformer(base.LoggedCallable):
    """Transform the input into a gerund phrase."""

    # the spacy pipeline used to parse inputs; NER is disabled since
    # only tagging / dependency parsing are used below
    _nlp = spacy.load('en', disable=['ner'])

    @classmethod
    def _is_root(
        cls,
        token: spacy.tokens.token.Token
    ) -> bool:
        """Return ``True`` if ``token`` is the root verb of its parse."""
        return token.dep_ == 'ROOT' and token.pos_ == 'VERB'

    @classmethod
    def _is_rootlike(
        cls,
        token: spacy.tokens.token.Token
    ) -> bool:
        """Return ``True`` if ``token`` is the root verb or conjoined to it."""
        return (
            cls._is_root(token)
            or any(cls._is_root(c) for c in token.conjuncts)
        )

    @classmethod
    def _conjugate(
        cls,
        text: str
    ) -> str:
        """Return ``text`` conjugated into the present progressive."""
        # handle some special cases
        # ("'m" / "m" come from contractions of "am"; "left" is
        # normalized to its lemma "leave" before conjugation)
        if text in ["'m", 'm']:
            text = 'am'
        elif text == 'left':
            text = 'leave'
        else:
            pass

        return conjugate(
            verb=text,
            tense=PRESENT,
            person=None,
            number=None,
            mood=INDICATIVE,
            aspect=PROGRESSIVE)

    def apply(
        self,
        x: str
    ) -> str:
        """Return ``x`` rewritten as a gerund phrase.

        The root verb of the first sentence (and its conjuncts) is
        conjugated into the present progressive, while its subject and
        auxiliary verbs are dropped; all other tokens are passed
        through with their original whitespace.

        Parameters
        ----------
        x : str, required
            The text to transform.

        Returns
        -------
        str
            The transformed text.
        """
        token_strs = []
        for token in self._nlp(x):
            if token.sent.start != 0:
                # skip conjugating verbs that aren't in the first
                # sentence
                token_strs.append(token.text_with_ws)
                continue

            if self._is_rootlike(token):
                # conjugate the token if it is like the root verb
                # (keep the token's trailing whitespace intact)
                token_str = self._conjugate(token.text)\
                    + token.text_with_ws[len(token.text):]
            elif token.dep_ == 'nsubj' and self._is_rootlike(token.head):
                # remove the subject attached to the root
                continue
            elif token.dep_ == 'aux' and self._is_rootlike(token.head):
                # remove auxiliary verbs attached to the root
                continue
            elif token.text in ["n't", 'nt'] and self._is_rootlike(token.head):
                # fix forms of "not" coming from contractions
                token_str = 'not' + token.text_with_ws[len(token.text):]
            else:
                # nothing to do here
                token_str = token.text_with_ws

            # NOTE(review): entries of token_strs retain trailing
            # whitespace, so this membership test only matches when the
            # preceding token had none — confirm this is intended.
            if (
                token_str.startswith('not')
                and len(token_strs) > 0
                and token_strs[-1] in ['being', 'doing', 'having']
            ):
                # swap "not" with the preceding verb
                token_str, token_strs[-1] = token_strs[-1], token_str

            token_strs.append(token_str)

        return ''.join(token_strs)
102 |
--------------------------------------------------------------------------------
/src/scruples/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | """Scripts for scruples."""
2 |
3 | import logging
4 |
5 | import click
6 |
7 | from . import (
8 | analyze,
9 | demo,
10 | evaluate,
11 | make)
12 | from .. import utils
13 |
14 |
logger = logging.getLogger(__name__)


# main function

@click.group()
@click.option('--verbose', is_flag=True, help='Set the log level to DEBUG.')
def scruples(verbose: bool) -> None:
    """The command line interface for scruples."""
    utils.configure_logging(verbose=verbose)


# register subcommands to the command group

subcommands = [analyze.analyze, demo.demo, evaluate.evaluate, make.make]

for command in subcommands:
    scruples.add_command(command)
41 |
--------------------------------------------------------------------------------
/src/scruples/scripts/analyze/__init__.py:
--------------------------------------------------------------------------------
1 | """Scripts for performing analyses."""
2 |
3 | import logging
4 |
5 | import click
6 |
7 | from . import (
8 | corpus,
9 | oracle_estimator,
10 | resource)
11 |
12 |
logger = logging.getLogger(__name__)


# main function

@click.group()
def analyze():
    """Run an analysis."""


# register subcommands to the command group

subcommands = [corpus.corpus, oracle_estimator.oracle_estimator, resource.resource]

for command in subcommands:
    analyze.add_command(command)
34 |
--------------------------------------------------------------------------------
/src/scruples/scripts/analyze/corpus/__init__.py:
--------------------------------------------------------------------------------
1 | """Scripts for analyzing the scruples corpus."""
2 |
3 | import logging
4 |
5 | import click
6 |
7 | from . import (
8 | extractions,
9 | human_performance,
10 | oracle_performance,
11 | predictions,
12 | statistics)
13 |
14 |
logger = logging.getLogger(__name__)


# main function

@click.group()
def corpus():
    """Analyze the scruples corpus."""


# register subcommands to the command group

subcommands = [
    extractions.extractions,
    human_performance.human_performance,
    oracle_performance.oracle_performance,
    predictions.predictions,
    statistics.statistics,
]

for command in subcommands:
    corpus.add_command(command)
38 |
--------------------------------------------------------------------------------
/src/scruples/scripts/analyze/corpus/human_performance.py:
--------------------------------------------------------------------------------
1 | """Estimate human performance for the scruples corpus."""
2 |
3 | import collections
4 | import json
5 | import logging
6 | import random
7 |
8 | import click
9 | import tqdm
10 |
11 | from .... import settings
12 | from ....baselines.metrics import METRICS
13 | from ....data.comment import Comment
14 | from ....data.post import Post
15 | from ....data.labels import Label
16 | from ....data.utils import instantiate_attrs_with_extra_kwargs
17 |
18 |
logger = logging.getLogger(__name__)


# helper functions

def _read_comments(comments_path: str):
    """Read the comments and group them by the ID of their post.

    Parameters
    ----------
    comments_path : str, required
        The path to the comments file.

    Returns
    -------
    collections.defaultdict
        A mapping from post IDs to lists of ``Comment`` objects.
    """
    link_id_to_comments = collections.defaultdict(list)
    with click.open_file(comments_path, 'r') as comments_file:
        # stream the file rather than materializing it all at once
        # with readlines
        for ln in tqdm.tqdm(comments_file, **settings.TQDM_KWARGS):
            comment = instantiate_attrs_with_extra_kwargs(
                Comment,
                **json.loads(ln))

            # IDs are usually prefixed with something like "t1_",
            # "t2_", etc. to denote what type of object it is. Slice
            # off the first 3 characters to remove this prefix from
            # the link id because it will not be on the posts' IDs
            # when we join the comments to them.
            link_id_to_comments[comment.link_id[3:]].append(comment)

    return link_id_to_comments


def _read_posts(posts_path: str, split_path: str, link_id_to_comments):
    """Read the posts belonging to the split, with comments attached.

    Parameters
    ----------
    posts_path : str, required
        The path to the posts file.
    split_path : str, required
        The path to the split file.
    link_id_to_comments : Mapping, required
        A mapping from post IDs to lists of ``Comment`` objects.

    Returns
    -------
    list
        The ``Post`` objects whose IDs appear in the split.
    """
    split_post_ids = set()
    with click.open_file(split_path, 'r') as split_file:
        for ln in split_file:
            split_post_ids.add(json.loads(ln)['post_id'])

    posts = []
    with click.open_file(posts_path, 'r') as posts_file:
        for ln in tqdm.tqdm(posts_file, **settings.TQDM_KWARGS):
            kwargs = json.loads(ln)
            post = instantiate_attrs_with_extra_kwargs(
                Post,
                comments=link_id_to_comments[kwargs['id']],
                **kwargs)

            if post.id in split_post_ids:
                posts.append(post)

    return posts


# main function

@click.command()
@click.argument(
    'comments_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'posts_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'split_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'output_path',
    type=click.Path(exists=False, file_okay=True, dir_okay=False))
def human_performance(
    comments_path: str,
    posts_path: str,
    split_path: str,
    output_path: str
) -> None:
    """Estimate human performance on the scruples corpus.

    Read in the comments from COMMENTS_PATH, posts from POSTS_PATH, and
    the split from SPLIT_PATH, then estimate human performance metrics
    and write them to OUTPUT_PATH.

    Human performance is estimated by holding out one comment from each
    post, and using its label as the human prediction. If the post has
    only a single comment, then the majority label is used as the
    prediction.
    """
    logger.info('Reading in comments.')

    link_id_to_comments = _read_comments(comments_path)

    logger.info('Reading in posts.')

    posts = _read_posts(posts_path, split_path, link_id_to_comments)

    logger.info('Computing human performance.')

    human_preds = []
    gold_labels = []
    for post in tqdm.tqdm(posts, **settings.TQDM_KWARGS):
        post_labels = [
            comment.label.index
            for comment in post.comments
            if comment.is_good and comment.label is not None
        ]
        random.shuffle(post_labels)

        if len(post_labels) > 1:
            # hold out the last label as the human prediction and use
            # the majority vote of the rest as the gold label
            human_preds.append(post_labels[-1])
            gold_labels.append(collections.Counter(
                post_labels[:-1]).most_common(1)[0][0])
        elif len(post_labels) == 1:
            # predict with the majority label
            human_preds.append(Label.OTHER.index)
            gold_labels.append(post_labels[0])
        else:
            raise ValueError('Found a post without a label.')

    # score the predictions on every hard-label metric
    with open(output_path, 'w') as metrics_file:
        json.dump({
            key: metric(
                y_true=gold_labels,
                y_pred=human_preds)
            for key, (_, metric, scorer_kwargs) in METRICS.items()
            if not scorer_kwargs['needs_proba']
        }, metrics_file)
120 |
--------------------------------------------------------------------------------
/src/scruples/scripts/analyze/corpus/oracle_performance.py:
--------------------------------------------------------------------------------
1 | """Estimate the oracle performance on the scruples corpus."""
2 |
3 | import json
4 | import logging
5 |
6 | import click
7 | import numpy as np
8 | from scipy import stats
9 | import tqdm
10 |
11 | from .... import utils, settings
12 | from ....data.labels import Label
13 | from ....baselines.metrics import METRICS
14 |
15 |
logger = logging.getLogger(__name__)


# constants

# Rendered with ``str.format`` at the end of ``oracle_performance``;
# ``{metrics_report}`` is the only placeholder.
REPORT_TEMPLATE =\
"""Scruples Corpus Oracle Performance
==================================
Oracle performance on the scruples corpus.


Oracle Metrics
--------------
The following metrics provide estimates of the oracle performance, i.e. the
performance in each metric attained when you use the _true_ underlying label
distribution for that instance to predict the labels.

We obtain these estimates using an empirical Bayesian methodology: first, we
fit a dirichlet-multinomial model to the data, where each instance's label
distribution is drawn from a dirichlet prior. Then, for each instance we
condition on the observed labels and compute the expected performance of
predicting with the true distribution over the posterior.

{metrics_report}
"""
41 |
42 |
43 | # main function
44 |
@click.command()
@click.argument(
    'dataset_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'output_path',
    type=click.Path(exists=False, file_okay=True, dir_okay=False))
def oracle_performance(
    dataset_path: str,
    output_path: str
) -> None:
    """Estimate oracle performance and write a report.

    Read in the dataset from DATASET_PATH, estimate the oracle
    performance and write the results to OUTPUT_PATH.
    """
    # the number of Monte Carlo samples used for each estimate
    n_samples = 10000

    label_name_to_idx = {
        label.name: label.index
        for label in Label
    }
    # Step 1: Read in the dataset.
    with click.open_file(dataset_path, 'r') as dataset_file:
        labels = []
        label_scores = []
        for ln in dataset_file:
            row = json.loads(ln)

            labels.append(label_name_to_idx[row['label']])

            # densify the label scores into a vector indexed by label
            scores = [0 for _ in Label]
            for label_name, score in row['label_scores'].items():
                scores[label_name_to_idx[label_name]] = score
            label_scores.append(scores)
    labels = np.array(labels)
    label_scores = np.array(label_scores)

    # Step 2: Estimate the dirichlet-multinomial parameters.
    params = utils.estimate_dirichlet_multinomial_parameters(
        label_scores)

    def sample_true_scores():
        # draw one sample of each instance's true label distribution
        # from the dirichlet posterior (prior params plus the observed
        # label scores)
        return np.array([
            stats.dirichlet.rvs([
                a + x
                for a, x in zip(params, scores)
            ], size=1)[0]
            for scores in label_scores
        ])

    # Step 3: Estimate the expected performance.
    metric_name_to_value = {}
    for name, metric, scorer_kwargs in METRICS.values():
        logger.info(f'Computing estimate for {name}.')
        # estimate performance on the usual metrics
        value_samples = []
        for _ in tqdm.tqdm(range(n_samples), **settings.TQDM_KWARGS):
            true_scores = sample_true_scores()
            value = metric(
                y_true=labels,
                y_pred=true_scores
                    if scorer_kwargs['needs_proba']
                    else np.argmax(true_scores, axis=1))
            value_samples.append(value)
        metric_name_to_value[name] = np.mean(value_samples)
    if 'xentropy' in metric_name_to_value:
        raise ValueError(
            'METRICS should not have a key named "xentropy". This issue'
            ' is a bug in the library, please notify the maintainers.')

    # estimate performance on the xentropy, which requires soft
    # ground-truth labels
    logger.info('Computing estimate for xentropy.')
    value_samples = []
    for _ in tqdm.tqdm(range(n_samples), **settings.TQDM_KWARGS):
        true_scores = sample_true_scores()
        value = utils.xentropy(
            y_true=label_scores / np.sum(label_scores, axis=1).reshape(-1, 1),
            y_pred=true_scores)
        value_samples.append(value)
    metric_name_to_value['xentropy'] = np.mean(value_samples)

    # Step 4: Write the report.
    metric_name_width = 1 + max(
        len(name)
        for name in metric_name_to_value.keys())
    metrics_report = '\n'.join(
        f'{name: <{metric_name_width}}: {value:.4f}'
        for name, value in metric_name_to_value.items())
    with click.open_file(output_path, 'w') as output_file:
        output_file.write(
            REPORT_TEMPLATE.format(
                metrics_report=metrics_report))
140 |
--------------------------------------------------------------------------------
/src/scruples/scripts/analyze/oracle_estimator.py:
--------------------------------------------------------------------------------
1 | """Simulation experiments for the oracle performance estimator."""
2 |
3 | import collections
4 | import json
5 | import logging
6 |
7 | import click
8 | import numpy as np
9 | from scipy import stats
10 | from sklearn import metrics
11 | import tqdm
12 |
13 | from scruples import settings, utils
14 | from scruples.data.labels import Label
15 | from scruples.dataset.readers import ScruplesCorpus
16 |
17 |
logger = logging.getLogger(__name__)


METRICS = [
    (
        'accuracy',
        metrics.accuracy_score,
        lambda pss: np.argmax(pss, axis=-1)
    ),
    (
        'f1 (macro)',
        lambda y_pred, y_true: metrics.f1_score(
            y_pred=y_pred, y_true=y_true,
            average='macro'),
        lambda pss: np.argmax(pss, axis=-1)
    ),
    (
        'xentropy',
        utils.xentropy,
        lambda pss: pss
    )
]
"""The metrics to run in the experiment.

Each tuple contains:

    (name, metric, make_predictions)

where ``name`` is a human-readable name for the metric, ``metric`` is a
callable taking ``y_true`` and ``y_pred`` keyword arguments, and
``make_predictions`` converts an array of per-instance label
probabilities into the predictions the metric expects (hard labels via
``argmax`` for accuracy and f1, the probabilities themselves for
xentropy).
"""
47 |
48 |
@click.command()
@click.argument(
    'corpus_dir',
    type=click.Path(exists=True, file_okay=False, dir_okay=True))
@click.argument(
    'output_path',
    type=click.Path(exists=False, file_okay=True, dir_okay=False))
def oracle_estimator(
    corpus_dir: str,
    output_path: str
) -> None:
    """Conduct simulation experiments for the oracle estimator.

    Read the corpus from CORPUS_DIR, conduct different simulation experiments
    to evaluate the robustness and effectiveness of the oracle performance
    estimator, and write the results to OUTPUT_PATH.
    """
    # N.B. The sampling statements below draw from numpy / scipy's
    # global random state, so reordering them changes the simulated
    # data.

    # Read in the corpus.

    logger.info(f'Reading data from {corpus_dir}.')

    # label_scores has one row per dev example; presumably each column
    # holds the annotation count for one label — verify against
    # ScruplesCorpus.
    _, _, _, label_scores = ScruplesCorpus(data_dir=corpus_dir).dev
    label_scores = label_scores.values

    n_total = len(label_scores)
    # fit the dirichlet-multinomial prior to the observed dev labels
    alphas = utils.estimate_dirichlet_multinomial_parameters(label_scores)

    # Create the scenarios to simulate. Each scenario provides:
    #
    #   ns  : the number of annotations for each simulated example
    #   pss : each simulated example's true label distribution
    #   ys  : label counts drawn from a multinomial with those
    #         parameters

    logger.info('Creating simulation scenarios.')

    scenarios = []
    # scenario 1: running the estimator on the corpus
    ns = np.sum(label_scores, axis=-1)
    pss = stats.dirichlet.rvs(alpha=alphas, size=n_total)
    ys = np.array([stats.multinomial.rvs(n, ps) for n, ps in zip(ns, pss)])
    scenarios.append({
        'name': 'corpus',
        'ns': ns,
        'pss': pss,
        'ys': ys
    })
    # scenario 2: running the estimator with 3 annotations per example
    ns = 3 * np.ones(n_total)
    pss = stats.dirichlet.rvs(alpha=alphas, size=n_total)
    ys = np.array([stats.multinomial.rvs(n, ps) for n, ps in zip(ns, pss)])
    scenarios.append({
        'name': '3 annotations',
        'ns': ns,
        'pss': pss,
        'ys': ys
    })
    # scenario 3: running the estimator with a non-dirichlet prior
    # (a shuffled mixture of three different dirichlets)
    ns = np.sum(label_scores, axis=-1)
    pss = np.concatenate([
        stats.dirichlet.rvs(alpha=[2, 1, 1, 1, 1], size=n_total//3),
        stats.dirichlet.rvs(alpha=[1, 1, 3, 1, 1], size=n_total//3),
        stats.dirichlet.rvs(alpha=[1, 1, 1, 1, 2], size=n_total - 2 * n_total//3)
    ])
    np.random.shuffle(pss)
    ys = np.array([stats.multinomial.rvs(n, ps) for n, ps in zip(ns, pss)])
    scenarios.append({
        'name': 'non-dirichlet prior',
        'ns': ns,
        'pss': pss,
        'ys': ys
    })

    # Run the simulations.

    logger.info('Running simulations.')

    results = collections.defaultdict(dict)
    for scenario in tqdm.tqdm(scenarios, **settings.TQDM_KWARGS):
        name = scenario['name']
        ns = scenario['ns']
        pss = scenario['pss']
        ys = scenario['ys']

        # the "oracle" numbers score predictions made from the true
        # distributions (pss) against the observed annotation
        # frequencies (ys)
        results[name]['oracle'] = {
            metric_name: metric(
                y_pred=make_predictions(pss),
                y_true=make_predictions(
                    ys / np.expand_dims(np.sum(ys, axis=-1), axis=-1)))
            for metric_name, metric, make_predictions in METRICS
        }

        # estimate oracle performance
        # NOTE(review): estimated_alphas is computed but never used
        # below; utils.oracle_performance presumably re-estimates the
        # parameters internally — confirm and consider removing this
        # line.
        estimated_alphas = utils.estimate_dirichlet_multinomial_parameters(ys)

        results[name]['estimate'] = {
            metric_name: utils.oracle_performance(
                ys=ys,
                metric=metric,
                make_predictions=make_predictions,
                n_samples=10000)[0]
            for metric_name, metric, make_predictions in tqdm.tqdm(
                METRICS, **settings.TQDM_KWARGS)
        }

    # Write the results to disk.

    logger.info(f'Writing results to {output_path}.')

    with click.open_file(output_path, 'w') as output_file:
        json.dump(results, output_file)
155 |
--------------------------------------------------------------------------------
/src/scruples/scripts/analyze/resource/__init__.py:
--------------------------------------------------------------------------------
1 | """Scripts for analyzing the scruples resource."""
2 |
3 | import logging
4 |
5 | import click
6 |
7 | from . import (
8 | human_performance,
9 | latent_traits,
10 | oracle_performance,
11 | predictions,
12 | topics,
13 | verbs)
14 |
15 |
logger = logging.getLogger(__name__)


# main function

@click.group()
def resource():
    """Analyze the scruples resource."""


# register subcommands to the command group

subcommands = [
    human_performance.human_performance,
    latent_traits.latent_traits,
    oracle_performance.oracle_performance,
    predictions.predictions,
    topics.topics,
    verbs.verbs,
]

for command in subcommands:
    resource.add_command(command)
40 |
--------------------------------------------------------------------------------
/src/scruples/scripts/analyze/resource/human_performance.py:
--------------------------------------------------------------------------------
1 | """Estimate human performance for the scruples resource."""
2 |
3 | import json
4 | import logging
5 |
6 | import click
7 |
8 | from ....baselines.metrics import METRICS
9 |
10 |
logger = logging.getLogger(__name__)


# main function

@click.command()
@click.argument(
    'split_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'output_path',
    type=click.Path(exists=False, file_okay=True, dir_okay=False))
def human_performance(
    split_path: str,
    output_path: str
) -> None:
    """Estimate human performance on the scruples resource.

    Read in the split from SPLIT_PATH, then estimate human performance
    metrics and write them to OUTPUT_PATH.

    Human performance is computed by comparing the majority vote label
    of the human performance annotators to the majority vote label of
    the gold annotators.
    """
    logger.info('Computing human performance.')

    # collect the human predictions and gold labels from the split
    with click.open_file(split_path, 'r') as split_file:
        rows = [json.loads(ln) for ln in split_file]

    human_preds = [row['human_perf_label'] for row in rows]
    gold_labels = [row['gold_label'] for row in rows]

    # score the predictions on every hard-label metric
    results = {}
    for key, (_, metric, scorer_kwargs) in METRICS.items():
        if scorer_kwargs['needs_proba']:
            # probability-based metrics can't score hard labels
            continue
        results[key] = metric(
            y_true=gold_labels,
            y_pred=human_preds)

    with open(output_path, 'w') as metrics_file:
        json.dump(results, metrics_file)
54 |
--------------------------------------------------------------------------------
/src/scruples/scripts/analyze/resource/oracle_performance.py:
--------------------------------------------------------------------------------
1 | """Estimate the oracle performance on the scruples resource."""
2 |
3 | import json
4 | import logging
5 |
6 | import click
7 | import numpy as np
8 | from scipy import stats
9 | import tqdm
10 |
11 | from .... import utils, settings
12 | from ....baselines.metrics import METRICS
13 |
14 |
logger = logging.getLogger(__name__)


# constants

# Rendered with ``str.format`` at the end of ``oracle_performance``;
# ``{metrics_report}`` is the only placeholder.
REPORT_TEMPLATE =\
"""Scruples Resource Oracle Performance
====================================
Oracle performance on the scruples resource.


Oracle Metrics
--------------
The following metrics provide estimates of the oracle performance, i.e. the
performance in each metric attained when you use the _true_ underlying label
distribution for that instance to predict the labels.

We obtain these estimates using an empirical Bayesian methodology: first, we
fit a dirichlet-multinomial model to the data, where each instance's label
distribution is drawn from a dirichlet prior. Then, for each instance we
condition on the observed labels and compute the expected performance of
predicting with the true distribution over the posterior.

{metrics_report}
"""
40 |
41 |
42 | # main function
43 |
@click.command()
@click.argument(
    'dataset_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'output_path',
    type=click.Path(exists=False, file_okay=True, dir_okay=False))
def oracle_performance(
    dataset_path: str,
    output_path: str
) -> None:
    """Estimate oracle performance and write a report.

    Read in the dataset from DATASET_PATH, estimate the oracle
    performance and write the results to OUTPUT_PATH.
    """
    # the number of Monte Carlo samples used for each estimate
    n_samples = 10000

    # Step 1: Read in the dataset.
    with click.open_file(dataset_path, 'r') as dataset_file:
        labels = []
        label_scores = []
        label_annotations_all = []
        for ln in dataset_file:
            row = json.loads(ln)

            labels.append(row['gold_label'])
            label_scores.append(row['gold_annotations'])
            # pool the gold and human performance annotations so the
            # prior is fit on all available annotations
            label_annotations_all.append(
                [
                    i + j
                    for i, j in zip(
                        row['gold_annotations'],
                        row['human_perf_annotations']
                    )
                ])

    labels = np.array(labels)
    label_scores = np.array(label_scores)
    label_annotations_all = np.array(label_annotations_all)

    # Step 2: Estimate the dirichlet-multinomial parameters.
    params = utils.estimate_dirichlet_multinomial_parameters(
        label_annotations_all)

    def sample_true_scores():
        # draw one sample of each instance's true label distribution
        # from the dirichlet posterior (prior params plus the observed
        # gold annotations)
        return np.array([
            stats.dirichlet.rvs([
                a + x
                for a, x in zip(params, scores)
            ], size=1)[0]
            for scores in label_scores
        ])

    # Step 3: Estimate the expected performance.
    metric_name_to_value = {}
    for name, metric, scorer_kwargs in METRICS.values():
        logger.info(f'Computing estimate for {name}.')
        # estimate performance on the usual metrics
        value_samples = []
        for _ in tqdm.tqdm(range(n_samples), **settings.TQDM_KWARGS):
            true_scores = sample_true_scores()
            value = metric(
                y_true=labels,
                y_pred=true_scores
                    if scorer_kwargs['needs_proba']
                    else np.argmax(true_scores, axis=1))
            value_samples.append(value)
        metric_name_to_value[name] = np.mean(value_samples)
    if 'xentropy' in metric_name_to_value:
        raise ValueError(
            'METRICS should not have a key named "xentropy". This issue'
            ' is a bug in the library, please notify the maintainers.')

    # estimate performance on the xentropy, which requires soft
    # ground-truth labels
    logger.info('Computing estimate for xentropy.')
    value_samples = []
    for _ in tqdm.tqdm(range(n_samples), **settings.TQDM_KWARGS):
        true_scores = sample_true_scores()
        value = utils.xentropy(
            y_true=label_scores / np.sum(label_scores, axis=1).reshape(-1, 1),
            y_pred=true_scores)
        value_samples.append(value)
    metric_name_to_value['xentropy'] = np.mean(value_samples)

    # Step 4: Write the report.
    metric_name_width = 1 + max(
        len(name)
        for name in metric_name_to_value.keys())
    metrics_report = '\n'.join(
        f'{name: <{metric_name_width}}: {value:.4f}'
        for name, value in metric_name_to_value.items())
    with click.open_file(output_path, 'w') as output_file:
        output_file.write(
            REPORT_TEMPLATE.format(
                metrics_report=metrics_report))
142 |
--------------------------------------------------------------------------------
/src/scruples/scripts/analyze/resource/topics.py:
--------------------------------------------------------------------------------
1 | """Infer topics for the action descriptions in the resource."""
2 |
3 | import json
4 | import logging
5 |
6 | import click
7 | from sklearn.decomposition import LatentDirichletAllocation
8 | from sklearn.feature_extraction.text import CountVectorizer
9 | from sklearn.pipeline import Pipeline
10 |
11 |
logger = logging.getLogger(__name__)


# constants

N_COMPONENTS = 5
"""The number of topics to fit."""
# Passed as ``n_components`` to LatentDirichletAllocation in the
# ``topics`` command below.
# N.B. This value was chosen interactively by examining the resulting
# topics and log-likelihood the topic model produces on the resource's
# dev set for various values of n_components.
22 |
23 |
24 | # main function
25 |
@click.command()
@click.argument(
    'split_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'output_path',
    type=click.Path(exists=False, file_okay=True, dir_okay=False))
def topics(
    split_path: str,
    output_path: str
) -> None:
    """Create topics for the action descriptions.

    Read in the actions split from SPLIT_PATH, apply a Latent Dirichlet
    Allocation based topic model to the descriptions of all the actions,
    then write out the top words from the resulting topics to
    OUTPUT_PATH.
    """
    logger.info('Fitting the topic model to the action descriptions.')

    model = Pipeline([
        (
            'vectorizer',
            CountVectorizer(
                input='content',
                encoding='utf-8',
                decode_error='strict',
                strip_accents=None,
                lowercase=True,
                preprocessor=None,
                tokenizer=None,
                stop_words='english',
                token_pattern=r'(?u)\b\w\w+\b',
                ngram_range=(1, 1),
                analyzer='word',
                max_df=1.0,
                min_df=2,
                max_features=None,
                vocabulary=None,
                binary=False,
            )
        ),
        (
            'topic_model',
            LatentDirichletAllocation(n_components=N_COMPONENTS)
        )
    ])

    with click.open_file(split_path, 'r') as split_file:
        descriptions = [
            action['description']
            for ln in split_file
            for action in json.loads(ln)['actions']
        ]

    model.fit(descriptions)

    vectorizer = model.named_steps['vectorizer']
    # get_feature_names was deprecated in scikit-learn 1.0 and removed
    # in 1.2 in favor of get_feature_names_out, so prefer the newer
    # method when available.
    if hasattr(vectorizer, 'get_feature_names_out'):
        feature_names = list(vectorizer.get_feature_names_out())
    else:
        feature_names = vectorizer.get_feature_names()
    components = model.named_steps['topic_model'].components_

    topics_ = [
        [
            {
                'word': word,
                'alpha': alpha
            }
            # sort by weight (ties broken on the word) and keep the 25
            # most heavily weighted words for each topic
            for alpha, word in sorted(
                zip(component, feature_names)
            )[-25:][::-1]
        ]
        for component in components
    ]

    logger.info("Writing out the topics' top words.")

    with click.open_file(output_path, 'w') as output_file:
        for topic in topics_:
            output_file.write(json.dumps(topic) + '\n')

    logger.info("Finished topic analysis.")
106 |
--------------------------------------------------------------------------------
/src/scruples/scripts/demo/__init__.py:
--------------------------------------------------------------------------------
1 | """Scripts for demos."""
2 |
3 | import logging
4 |
5 | import click
6 |
7 | from . import (
8 | norms,
9 | scoracle)
10 |
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | # main function
16 |
@click.group()
def demo():
    """Run a demo's server."""
    # no body: the group exists purely to dispatch to its subcommands


# attach each demo's subcommand to the group

subcommands = [
    norms.norms,
    scoracle.scoracle,
]

for command in subcommands:
    demo.add_command(command)
32 |
--------------------------------------------------------------------------------
/src/scruples/scripts/demo/norms.py:
--------------------------------------------------------------------------------
1 | """Run the norms demo's server."""
2 |
3 | import logging
4 |
5 | import click
6 | from gevent.pywsgi import WSGIServer
7 |
8 | from ... import settings
9 | from ...demos.norms.app import (
10 | app,
11 | get_device,
12 | load_model)
13 |
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
@click.command()
@click.option(
    '--expose', is_flag=True,
    help='Whether to expose the server to the internet, or run on localhost.')
@click.option(
    '--port', type=int, default=5050,
    help='The port on which to serve the demo. Defaults to 5050.')
def norms(
        expose: bool,
        port: int
) -> None:
    """Serve the norms demo.

    In order to run this server, you must set the following environment
    variables:

    \b
    SCRUPLES_NORMS_ACTIONS_BASELINE : The baseline to use for the resource
        (actions).
    SCRUPLES_NORMS_ACTIONS_MODEL : The path to the saved pretrained
        model to use for predicting the actions.
    SCRUPLES_NORMS_CORPUS_BASELINE : The baseline to use for the corpus.
    SCRUPLES_NORMS_CORPUS_MODEL : The path to the saved pretrained
        model to use for predicting the corpus.
    SCRUPLES_NORMS_PREDICT_BATCH_SIZE : The batch size to use for
        prediction.
    SCRUPLES_NORMS_GPU_IDS : A comma separated list of GPU IDs to
        use. If none are provided, then the CPU will be used instead.

    """
    # set up the compute device, then warm both models before serving so
    # the first request doesn't pay the model-loading cost

    get_device()

    logger.info('Loading the Actions model.')
    load_model(dataset='resource')

    logger.info('Loading the Corpus model.')
    load_model(dataset='corpus')

    # bind to every interface only when the caller explicitly asked to
    # expose the server; otherwise stay on localhost
    if expose:
        host = '0.0.0.0'
    else:
        host = '127.0.0.1'

    logger.info(f'Running norms server on http://{host}:{port}/')

    server = WSGIServer((host, port), app)
    server.serve_forever()
65 |
--------------------------------------------------------------------------------
/src/scruples/scripts/demo/scoracle.py:
--------------------------------------------------------------------------------
1 | """Run the scoracle demo's server."""
2 |
3 | import logging
4 |
5 | import click
6 | from gevent.pywsgi import WSGIServer
7 |
8 | from scruples.demos.scoracle.app import app
9 |
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
@click.command()
@click.option(
    '--expose', is_flag=True,
    help='Whether to expose the server to the internet, or run on localhost.')
@click.option(
    '--port', type=int, default=5000,
    help='The port on which to serve the demo. Defaults to 5000.')
def scoracle(
        expose: bool,
        port: int
) -> None:
    """Serve the scoracle demo."""
    # bind to every interface only when explicitly exposed; otherwise
    # stay on localhost
    if expose:
        host = '0.0.0.0'
    else:
        host = '127.0.0.1'

    logger.info(f'Running scoracle server on http://{host}:{port}/')

    server = WSGIServer((host, port), app)
    server.serve_forever()
31 |
--------------------------------------------------------------------------------
/src/scruples/scripts/evaluate/__init__.py:
--------------------------------------------------------------------------------
1 | """Scripts for evaluating models on scruples."""
2 |
3 | import logging
4 |
5 | import click
6 |
7 | from . import (
8 | resource,
9 | corpus)
10 |
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | # main function
16 |
@click.group()
def evaluate():
    """Evaluate models on scruples."""
    # no body: the group exists purely to dispatch to its subcommands


# attach each evaluation subcommand to the group

subcommands = [
    resource.resource,
    corpus.corpus
]

for command in subcommands:
    evaluate.add_command(command)
32 |
--------------------------------------------------------------------------------
/src/scruples/scripts/evaluate/corpus/__init__.py:
--------------------------------------------------------------------------------
1 | """Scripts for evaluating models on the scruples corpus."""
2 |
3 | import logging
4 |
5 | import click
6 |
7 | from . import (
8 | predict_lm,
9 | run_shallow,
10 | tune_lm)
11 |
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | # main function
17 |
@click.group()
def corpus():
    """Evaluate baseline models on the scruples corpus."""
    # no body: the group exists purely to dispatch to its subcommands


# attach each corpus-evaluation subcommand to the group

subcommands = [
    predict_lm.predict_lm,
    run_shallow.run_shallow,
    tune_lm.tune_lm
]

for command in subcommands:
    corpus.add_command(command)
34 |
--------------------------------------------------------------------------------
/src/scruples/scripts/evaluate/corpus/run_shallow.py:
--------------------------------------------------------------------------------
1 | """Run shallow baseline models on the scruples corpus."""
2 |
3 | import collections
4 | import json
5 | import logging
6 | import os
7 | from typing import List
8 |
9 | import click
10 | import dill
11 | from sklearn.metrics import make_scorer
12 | from skopt import BayesSearchCV
13 | import tqdm
14 |
15 | from .... import settings, baselines
16 | from ....baselines.metrics import METRICS
17 | from ....dataset.readers import ScruplesCorpus
18 |
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | # main function
24 |
@click.command()
@click.argument(
    'data_dir',
    type=click.Path(exists=True, file_okay=False, dir_okay=True))
@click.argument(
    'results_dir',
    type=click.Path(exists=False, file_okay=False, dir_okay=True))
@click.argument(
    'splits', type=click.Choice(ScruplesCorpus.SPLITS), nargs=-1)
@click.option(
    '--metric',
    type=click.Choice(METRICS.keys()),
    default='log_loss',
    help='The metric to use for hyper-parameter tuning. Defaults to'
         ' log_loss.')
@click.option(
    '--n-iter', type=int, default=128,
    help='The number of iterations of Bayesian optimization to run when'
         ' tuning baseline hyper-parameters. Defaults to 128.')
@click.option(
    '--n-points', type=int, default=8,
    help='The number of points to evaluate in parallel during Bayesian'
         ' optimization. Defaults to 8.')
@click.option(
    '--n-folds', type=int, default=4,
    help='The number of cross-validation folds to use. Defaults to 4.')
@click.option(
    '--n-jobs', type=int, default=0,
    help='The number of parallel processes to use for tuning'
         ' hyper-parameters. At most n_folds * n_points processses can'
         ' be used at a given time. If 0, then the same number of'
         ' processes as CPUs will be used. Defaults to 0.')
def run_shallow(
        data_dir: str,
        results_dir: str,
        splits: List[str],
        metric: str,
        n_iter: int,
        n_points: int,
        n_folds: int,
        n_jobs: int
) -> None:
    """Evaluate shallow baselines on the scruples corpus.

    Train shallow baseline models on the scruples corpus, reading the
    dataset from DATA_DIR, and writing trained models, logs, and other
    results to RESULTS_DIR. Performance is reported for each split
    provided as an argument.
    """
    # Step 1: Manage and construct paths.

    logger.info('Creating the results directory.')

    # N.B. ``os.makedirs`` raises if RESULTS_DIR already exists, which
    # protects previous results from being overwritten.
    os.makedirs(results_dir)
    # map baseline -> model path, and baseline -> split -> output paths
    model_paths = {}
    metrics_paths = collections.defaultdict(dict)
    predictions_paths = collections.defaultdict(dict)
    for baseline in baselines.corpus.SHALLOW_BASELINES.keys():
        os.makedirs(os.path.join(results_dir, baseline))
        model_paths[baseline] = os.path.join(
            results_dir, baseline, 'model.pkl')
        for split in splits:
            os.makedirs(os.path.join(results_dir, baseline, split))
            metrics_paths[baseline][split] = os.path.join(
                results_dir, baseline, split, 'metrics.json')
            predictions_paths[baseline][split] = os.path.join(
                results_dir, baseline, split, 'predictions.jsonl')

    # Step 2: Load the data.

    logger.info(f'Loading data from {data_dir}.')

    dataset = ScruplesCorpus(data_dir=data_dir)

    # Step 3: Run the baselines.

    logger.info('Running the baselines.')

    for baseline, (Model, hyper_parameter_space) in tqdm.tqdm(
            baselines.corpus.SHALLOW_BASELINES.items(),
            **settings.TQDM_KWARGS
    ):
        # tune the hyper-parameters and train the model
        ids, features, labels, label_scores = dataset.train
        if hyper_parameter_space:
            # wrap the estimator in a Bayesian hyper-parameter search,
            # scored with the metric chosen via --metric, refitting the
            # best configuration on the full training split
            model = BayesSearchCV(
                Model,
                hyper_parameter_space,
                scoring=make_scorer(
                    score_func=METRICS[metric][1],
                    **METRICS[metric][2]),
                n_iter=n_iter,
                n_points=n_points,
                cv=n_folds,
                n_jobs=os.cpu_count() if n_jobs == 0 else n_jobs,
                refit=True)
        else:
            # no search space, so train the estimator as-is
            # NOTE(review): this reuses the shared ``Model`` object from
            # SHALLOW_BASELINES rather than a fresh clone — confirm
            # that's intended.
            model = Model
        model.fit(features, labels)

        # Step 4: Save the model.

        with open(model_paths[baseline], 'wb') as model_file:
            dill.dump(model, model_file)

        # Step 5: Run evaluation on the splits.

        for split in splits:
            ids, features, labels, label_scores = getattr(dataset, split)

            predictions = model.predict(features)
            probabilities = model.predict_proba(features)

            # compute every metric, feeding probabilities to the metrics
            # that need them and hard predictions to the rest
            # N.B. the ``metric`` name below is local to the
            # comprehension and does not clash with the --metric option.
            with open(metrics_paths[baseline][split], 'w') as metrics_file:
                json.dump(
                    {
                        key: metric(
                            y_true=labels,
                            y_pred=probabilities
                            if scorer_kwargs['needs_proba']
                            else predictions)
                        for key, (_, metric, scorer_kwargs) in METRICS.items()
                    },
                    metrics_file)

            # write one JSON object per instance with its predicted
            # label and per-class probabilities
            with open(predictions_paths[baseline][split], 'w')\
                    as predictions_file:
                for id_, probs, prediction in zip(
                    ids, probabilities, predictions
                ):
                    # assumes ``prediction`` is already JSON-serializable
                    # (e.g. a string label) — TODO confirm
                    predictions_file.write(
                        json.dumps({
                            'id': id_,
                            'label': prediction,
                            'label_scores': {
                                class_: prob.tolist()
                                for class_, prob
                                in zip(model.classes_, probs)
                            }
                        }) + '\n')
165 |
--------------------------------------------------------------------------------
/src/scruples/scripts/evaluate/resource/__init__.py:
--------------------------------------------------------------------------------
1 | """Scripts for evaluating models on the scruples resource."""
2 |
3 | import logging
4 |
5 | import click
6 |
7 | from . import (
8 | predict_lm,
9 | run_shallow,
10 | tune_lm)
11 |
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | # main function
17 |
@click.group()
def resource():
    """Evaluate baseline models on the scruples resource."""
    # no body: the group exists purely to dispatch to its subcommands


# attach each resource-evaluation subcommand to the group

subcommands = [
    predict_lm.predict_lm,
    run_shallow.run_shallow,
    tune_lm.tune_lm
]

for command in subcommands:
    resource.add_command(command)
34 |
--------------------------------------------------------------------------------
/src/scruples/scripts/evaluate/resource/run_shallow.py:
--------------------------------------------------------------------------------
1 | """Run shallow baseline models on the scruples resource."""
2 |
3 | import collections
4 | import json
5 | import logging
6 | import os
7 | from typing import List
8 |
9 | import click
10 | import dill
11 | from sklearn.metrics import make_scorer
12 | from skopt import BayesSearchCV
13 | import tqdm
14 |
15 | from .... import settings, baselines
16 | from ....baselines.metrics import METRICS
17 | from ....dataset.readers import ScruplesResource
18 |
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | # main function
24 |
@click.command()
@click.argument(
    'data_dir',
    type=click.Path(exists=True, file_okay=False, dir_okay=True))
@click.argument(
    'results_dir',
    type=click.Path(exists=False, file_okay=False, dir_okay=True))
@click.argument(
    'splits', type=click.Choice(ScruplesResource.SPLITS), nargs=-1)
@click.option(
    '--metric',
    type=click.Choice(METRICS.keys()),
    default='log_loss',
    help='The metric to use for hyper-parameter tuning. Defaults to'
         ' log_loss.')
@click.option(
    '--n-iter', type=int, default=128,
    help='The number of iterations of Bayesian optimization to run when'
         ' tuning baseline hyper-parameters. Defaults to 128.')
@click.option(
    '--n-points', type=int, default=8,
    help='The number of points to evaluate in parallel during Bayesian'
         ' optimization. Defaults to 8.')
@click.option(
    '--n-folds', type=int, default=4,
    help='The number of cross-validation folds to use. Defaults to 4.')
@click.option(
    '--n-jobs', type=int, default=0,
    help='The number of parallel processes to use for tuning'
         ' hyper-parameters. At most n_folds * n_points processses can'
         ' be used at a given time. If 0, then the same number of'
         ' processes as CPUs will be used. Defaults to 0.')
def run_shallow(
        data_dir: str,
        results_dir: str,
        splits: List[str],
        metric: str,
        n_iter: int,
        n_points: int,
        n_folds: int,
        n_jobs: int
) -> None:
    """Evaluate shallow baselines on the scruples resource.

    Train shallow baseline models on the scruples resource, reading
    the dataset from DATA_DIR, and writing trained models, logs, and
    other results to RESULTS_DIR. Performance is reported for each split
    provided as an argument.
    """
    # Step 1: Manage and construct paths.

    logger.info('Creating the results directory.')

    # N.B. ``os.makedirs`` raises if RESULTS_DIR already exists, which
    # protects previous results from being overwritten.
    os.makedirs(results_dir)
    # map baseline -> model path, and baseline -> split -> output paths
    model_paths = {}
    metrics_paths = collections.defaultdict(dict)
    predictions_paths = collections.defaultdict(dict)
    for baseline in baselines.resource.SHALLOW_BASELINES.keys():
        os.makedirs(os.path.join(results_dir, baseline))
        model_paths[baseline] = os.path.join(
            results_dir, baseline, 'model.pkl')
        for split in splits:
            os.makedirs(os.path.join(results_dir, baseline, split))
            metrics_paths[baseline][split] = os.path.join(
                results_dir, baseline, split, 'metrics.json')
            predictions_paths[baseline][split] = os.path.join(
                results_dir, baseline, split, 'predictions.jsonl')

    # Step 2: Load the data.

    logger.info(f'Loading the data from {data_dir}.')

    dataset = ScruplesResource(data_dir=data_dir)

    # Step 3: Run the baselines.

    logger.info('Running the baselines.')

    for baseline, (Model, hyper_parameter_space) in tqdm.tqdm(
            baselines.resource.SHALLOW_BASELINES.items(),
            **settings.TQDM_KWARGS
    ):
        # tune the hyper-parameters and train the model
        ids, features, labels, label_scores = dataset.train
        if hyper_parameter_space:
            # wrap the estimator in a Bayesian hyper-parameter search,
            # scored with the metric chosen via --metric, refitting the
            # best configuration on the full training split
            model = BayesSearchCV(
                Model,
                hyper_parameter_space,
                scoring=make_scorer(
                    score_func=METRICS[metric][1],
                    **METRICS[metric][2]),
                n_iter=n_iter,
                n_points=n_points,
                cv=n_folds,
                n_jobs=os.cpu_count() if n_jobs == 0 else n_jobs,
                refit=True)
        else:
            # no search space, so train the estimator as-is
            # NOTE(review): this reuses the shared ``Model`` object from
            # SHALLOW_BASELINES rather than a fresh clone — confirm
            # that's intended.
            model = Model
        model.fit(features, labels)

        # Step 4: Save the model.

        with open(model_paths[baseline], 'wb') as model_file:
            dill.dump(model, model_file)

        # Step 5: Run evaluation on the splits.

        for split in splits:
            ids, features, labels, label_scores = getattr(dataset, split)

            predictions = model.predict(features)
            probabilities = model.predict_proba(features)

            # compute every metric, feeding probabilities to the metrics
            # that need them and hard predictions to the rest
            # N.B. the ``metric`` name below is local to the
            # comprehension and does not clash with the --metric option.
            with open(metrics_paths[baseline][split], 'w') as metrics_file:
                json.dump(
                    {
                        key: metric(
                            y_true=labels,
                            y_pred=probabilities
                            if scorer_kwargs['needs_proba']
                            else predictions)
                        for key, (_, metric, scorer_kwargs) in METRICS.items()
                    },
                    metrics_file)

            # write one JSON object per instance; ``.tolist()`` converts
            # the numpy values into plain Python types for JSON
            with open(predictions_paths[baseline][split], 'w')\
                    as predictions_file:
                for id_, probs, prediction in zip(
                        ids, probabilities, predictions
                ):
                    predictions_file.write(
                        json.dumps({
                            'id': id_,
                            'label': prediction.tolist(),
                            'label_scores': probs.tolist()
                        }) + '\n')
161 |
--------------------------------------------------------------------------------
/src/scruples/scripts/make/__init__.py:
--------------------------------------------------------------------------------
1 | """Scripts for making the different components of scruples."""
2 |
3 | import logging
4 |
5 | import click
6 |
7 | from . import (
8 | resource,
9 | corpus)
10 |
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | # main function
16 |
@click.group()
def make():
    """Make different components of scruples."""
    # no body: the group exists purely to dispatch to its subcommands


subcommands = [
    resource.resource,
    corpus.corpus
]

for command in subcommands:
    make.add_command(command)
30 |
--------------------------------------------------------------------------------
/src/scruples/scripts/make/corpus.py:
--------------------------------------------------------------------------------
1 | """Make the scruples corpus from raw reddit data.
2 |
3 | This script takes in posts and comments from the reddit API and creates
4 | the scruples dataset.
5 | """
6 |
7 | import collections
8 | import json
9 | import logging
10 | import os
11 | import random
12 |
13 | import attr
14 | import click
15 | import tqdm
16 |
17 | from ... import settings, utils
18 | from ...data.comment import Comment
19 | from ...data.post import Post
20 | from ...data.utils import instantiate_attrs_with_extra_kwargs
21 |
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | # main function
27 |
@click.command()
@click.argument(
    'comments_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'posts_path',
    type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument(
    'corpus_dir',
    type=click.Path(exists=False, file_okay=False, dir_okay=True))
def corpus(
        comments_path: str,
        posts_path: str,
        corpus_dir: str
) -> None:
    """Create the scruples corpus and write it to CORPUS_DIR.

    Read in the reddit posts from POSTS_PATH and comments from
    COMMENTS_PATH, create the scruples corpus, and write it to
    CORPUS_DIR.
    """
    # Create the output directory.
    # N.B. ``os.makedirs`` raises if CORPUS_DIR already exists, which
    # protects an existing corpus from being overwritten.
    os.makedirs(corpus_dir)

    # Step 1: Read in the comments and index them by their link ids.
    logger.info('Reading the comments.')

    link_id_to_comments = collections.defaultdict(list)
    with click.open_file(comments_path, 'r') as comments_file:
        # ``readlines`` pulls the whole file into memory, which also
        # gives tqdm a total for its progress bar
        for ln in tqdm.tqdm(comments_file.readlines(), **settings.TQDM_KWARGS):
            comment = instantiate_attrs_with_extra_kwargs(
                Comment,
                **json.loads(ln))

            # IDs are usually prefixed with something like "t1_",
            # "t2_", etc. to denote what type of object it is. Slice
            # off the first 3 characters to remove this prefix from
            # the link id because it will not be on the posts' IDs
            # when we join the comments to them.
            link_id_to_comments[comment.link_id[3:]].append(comment)

    # Step 2: Read in the posts and join them with their comments.
    logger.info('Reading the posts.')

    posts = []
    with click.open_file(posts_path, 'r') as posts_file:
        for ln in tqdm.tqdm(posts_file.readlines(), **settings.TQDM_KWARGS):
            kwargs = json.loads(ln)
            # a post with no comments gets the defaultdict's empty list
            post = instantiate_attrs_with_extra_kwargs(
                Post,
                comments=link_id_to_comments[kwargs['id']],
                **kwargs)

            posts.append(post)

    # Step 3: Write the posts to disk.
    logger.info('Writing the posts to disk.')

    processed_posts_path = os.path.join(corpus_dir, settings.POSTS_FILENAME)
    with open(processed_posts_path, 'w') as processed_posts_file:
        for post in posts:
            processed_posts_file.write(json.dumps(attr.asdict(post)) + '\n')

    # Step 4: Filter out bad posts.
    logger.info('Filtering out bad posts.')

    dataset_posts = [
        post
        for post in tqdm.tqdm(posts, **settings.TQDM_KWARGS)
        if post.is_good
    ]

    # Step 5: Create the splits then write them to disk.
    logger.info('Creating splits and writing them to disk.')

    # Shuffle dataset_posts so that the splits will be random.
    # NOTE(review): the shuffle is unseeded, so splits are not
    # reproducible across runs — confirm this is intended.
    random.shuffle(dataset_posts)

    # NOTE(review): this validation runs after the posts have already
    # been written to disk; consider validating SPLITS up front.
    if [split['size'] for split in settings.SPLITS].count(None) > 1:
        raise ValueError(
            'The settings.SPLITS constant should have at most ONE split'
            ' with a size of None.')

    # Make sure that the split with a size of ``None`` will be processed
    # last. A ``None`` size means "all remaining posts", so that split
    # must consume whatever is left after the fixed-size splits.
    splits = [
        split
        for split in settings.SPLITS
        if split['size'] is not None
    ] + [
        split
        for split in settings.SPLITS
        if split['size'] is None
    ]
    for split in splits:
        split_path = os.path.join(
            corpus_dir,
            settings.CORPUS_FILENAME_TEMPLATE.format(split=split['name']))
        with open(split_path, 'w') as split_file:
            # carve this split's posts off the front of dataset_posts,
            # leaving the remainder for the later splits
            if split['size'] is None:
                split_posts = dataset_posts
                dataset_posts = []
            else:
                split_posts = dataset_posts[:split['size']]
                dataset_posts = dataset_posts[split['size']:]
            for post in tqdm.tqdm(split_posts, **settings.TQDM_KWARGS):
                instance = {
                    'id': utils.make_id(),
                    'post_id': post.id,
                    'action':
                        attr.asdict(post.action)
                        if post.action is not None else
                        None,
                    'title': post.title,
                    'text': post.original_text,
                    'post_type': post.post_type.name,
                    'label_scores': {
                        label.name: score
                        for label, score
                        in post.label_scores.label_to_score.items()
                    },
                    'label': post.label_scores.best_label.name,
                    'binarized_label_scores': {
                        binarized_label.name: score
                        for binarized_label, score
                        in post.label_scores.binarized_label_to_score.items()
                    },
                    'binarized_label': post.label_scores\
                        .best_binarized_label.name
                }

                split_file.write(json.dumps(instance) + '\n')
160 |
--------------------------------------------------------------------------------
/src/scruples/scripts/make/resource/__init__.py:
--------------------------------------------------------------------------------
1 | """Scripts for making the scruples resource."""
2 |
3 | import logging
4 |
5 | import click
6 |
7 | from . import (
8 | dataset,
9 | proposals)
10 |
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | # main function
16 |
@click.group()
def resource():
    """Make different components of the scruples resource."""
    # no body: the group exists purely to dispatch to its subcommands


subcommands = [
    dataset.dataset,
    proposals.proposals
]

for command in subcommands:
    resource.add_command(command)
30 |
--------------------------------------------------------------------------------
/src/scruples/scripts/make/resource/proposals.py:
--------------------------------------------------------------------------------
1 | """Create candidate instances for the scruples resource.
2 |
3 | This script takes in the scruples corpus and creates candidate
4 | instances for the scruples resource of ranked action pairs. The
5 | proposals then must be annotated using Mechanical Turk.
6 | """
7 |
8 | import json
9 | import logging
10 | import os
11 | import random
12 |
13 | import click
14 | import tqdm
15 |
16 | from .... import settings, utils
17 | from ....data.action import Action
18 |
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | # main function
24 |
@click.command()
@click.argument(
    'corpus_dir',
    type=click.Path(exists=True, file_okay=False, dir_okay=True))
@click.argument(
    'proposals_dir',
    type=click.Path(exists=False, file_okay=False, dir_okay=True))
@click.option(
    '--rounds', type=int, default=3,
    help='The number of rounds of random matchings to run. Each round'
         ' uses every action from the resource once. So, k rounds will'
         ' use an action at most k times. Duplicate action pairs are'
         ' dropped after each round. Defaults to 3.')
def proposals(
        corpus_dir: str,
        proposals_dir: str,
        rounds: int
) -> None:
    """Propose and write instances for the resource to PROPOSALS_DIR.

    Read in the scruples corpus from CORPUS_DIR and then for each
    split in the corpus, randomly pair actions together and write them
    to PROPOSALS_DIR in a format ready for being annotated on Mechanical
    Turk.
    """
    # Create the output directory
    # N.B. ``os.makedirs`` raises if PROPOSALS_DIR already exists, which
    # protects existing proposals from being overwritten.
    os.makedirs(proposals_dir)

    # Iterate over each split of the corpus, creating the resource data
    # for each.
    for split in settings.SPLITS:
        logger.info(f'Reading the {split["name"]} split from the corpus.')

        # collect a (corpus id, action) pair for every usable action in
        # the split
        ids_actions = []
        corpus_split_path = os.path.join(
            corpus_dir,
            settings.CORPUS_FILENAME_TEMPLATE.format(split=split['name']))
        with open(corpus_split_path, 'r') as corpus_split_file:
            for ln in tqdm.tqdm(corpus_split_file, **settings.TQDM_KWARGS):
                row = json.loads(ln)
                if row['action'] is not None:
                    action = Action(**row['action'])
                    if action.is_good:
                        ids_actions.append((row['id'], action))

        logger.info(f'Computing random matchings for {rounds} rounds.')

        seen_pairs = set()
        instances = []
        for _ in range(rounds):
            # shuffle, then pair up adjacent actions; when the number of
            # actions is odd, the last action sits this round out
            random.shuffle(ids_actions)
            for i in range(0, len(ids_actions) - 1, 2):
                # check if the pair is a duplicate

                id0, _ = ids_actions[i]
                id1, _ = ids_actions[i+1]
                # canonicalize the order of the pair
                id0, id1 = sorted([id0, id1])
                # check if the pair has been seen before
                if (id0, id1) in seen_pairs:
                    continue
                else:
                    seen_pairs.add((id0, id1))

                # add the instance to the proposals

                instances.append({
                    'id': utils.make_id(),
                    'actions': [
                        {
                            'id': id_,
                            'description': action.description
                        }
                        for id_, action in ids_actions[i:i+2]
                    ]
                })

        logger.info(f'Writing proposals to {proposals_dir}.')

        # batch the instances into groups of N_INSTANCES_PER_HIT,
        # writing one HIT (a JSON object) per line
        proposals_split_path = os.path.join(
            proposals_dir,
            settings.PROPOSALS_FILENAME_TEMPLATE.format(split=split['name']))
        with open(proposals_split_path, 'w') as proposals_split_file:
            for i in range(0, len(instances), settings.N_INSTANCES_PER_HIT):
                proposals_split_file.write(json.dumps({
                    'instances': instances[i:i+settings.N_INSTANCES_PER_HIT]
                }) + '\n')
112 |
--------------------------------------------------------------------------------
/src/scruples/settings.py:
--------------------------------------------------------------------------------
1 | """Constants and settings."""
2 |
3 | import os
4 |
5 |
# subreddit related constants

AUTO_MODERATOR_NAME = 'AutoModerator'
"""The name of the AutoModerator bot."""


# dataset parameters

# splits

SPLITS = [
    {
        'name': 'train',
        'size': None
    },
    {
        'name': 'dev',
        'size': 2500
    },
    {
        'name': 'test',
        'size': 2500
    }
]
"""Definitions for the various dataset splits.

A ``'size'`` of ``None`` signifies that the all the data that is not
used in the other splits should go to that split.
"""
# N.B. This variable is the single-source-of-truth for the splits, their
# names and their sizes (in terms of the number of posts used to create
# them). If this variable is modified, make sure to update the doc
# strings for ``ScruplesCorpus``, ``ScruplesCorpusDataset``.
#
# Note that the resource splits will have a different number of
# instances than the number of posts used to create them, in
# general. So, the 'size' attribute doesn't necessarily give the size of
# the resource splits.

# corpus settings

CORPUS_FILENAME_TEMPLATE = '{split}.scruples-corpus.jsonl'
"""A template string for the corpus's split filenames."""

POSTS_FILENAME = 'all.scruples-posts.jsonl'
"""The filename for the file containing all the posts."""


# resource settings

PROPOSALS_FILENAME_TEMPLATE = '{split}.scruples-proposals.jsonl'
"""A template string for the resource proposals' split filenames."""

RESOURCE_FILENAME_TEMPLATE = '{split}.scruples-actions.jsonl'
"""A template string for the resource's split filenames."""

N_ANNOTATORS_FOR_GOLD_LABELS = 5
"""The number of annotators to use for creating the gold labels."""

MIN_AGREEMENT = 5
"""The minimum number of gold annotators required to agree.

The minimum number of gold annotators required to agree for the instance
to be considered non-controversial.
"""
# N.B. Since MIN_AGREEMENT equals N_ANNOTATORS_FOR_GOLD_LABELS, the gold
# annotators currently must be unanimous for an instance to count as
# non-controversial.

N_ANNOTATORS_FOR_HUMAN_PERFORMANCE = 5
"""The number of annotators to use for evaluating human performance."""

N_INSTANCES_PER_HIT = 20
"""The number of instances annotated in a single HIT."""


# evaluation

LOSS_TYPES = [
    'xentropy-hard',
    'xentropy-soft',
    'xentropy-full',
    'dirichlet-multinomial'
]
"""The different loss types for deep baseline models.

``"xentropy-hard"`` uses cross-entropy on the hard labels derived from
the plurality.

``"xentropy-soft"`` uses cross-entropy against soft labels derived from
averaging the individual labels together.

``"xentropy-full"`` uses the full negative log-likelihood objective with
all of the annotations. So, unlike ``"xentropy-soft"``, it doesn't
average the annotations then compute the cross-entropy, but simply sums
the contributions from each label. This loss is equivalent to
``"xentropy-soft"`` in the case where each instance has the same number
of annotations.

``"dirichlet-multinomial"`` uses a dirichlet-multinomial likelihood
where the model is predicting the parameters for the dirichlet
distribution as part of the hierarchical model.
"""
106 |
107 |
108 | # demos
109 |
110 | def _coerce_if_not_none(value, type_):
111 | if value is None:
112 | return None
113 |
114 | return type_(value)
115 |
116 |
117 | NORMS_ACTIONS_BASELINE = _coerce_if_not_none(
118 | os.environ.get('SCRUPLES_NORMS_ACTIONS_BASELINE'),
119 | str)
120 | """The baseline to use for predicting the actions in the norms demo.
121 |
122 | This constant should be one of the keys from
123 | ``scruples.baselines.resource.FINE_TUNE_LM_BASELINES`` and should correspond to
124 | the ``NORMS_ACTIONS_MODEL`` setting.
125 | """
126 |
127 | NORMS_ACTIONS_MODEL = _coerce_if_not_none(
128 | os.environ.get('SCRUPLES_NORMS_ACTIONS_MODEL'),
129 | str)
130 | """The path to the model directory to use for predicting the actions.
131 |
132 | The path to the fine-tuned Dirichlet-multinomial likelihood model to use when
133 | predicting the actions in the ``norms`` demo. The chosen directory should be
134 | the result of calling ``.save_pretrained`` on the model instance. See the
135 | transformers_ library for more details.
136 |
137 | .. _transformers: https://github.com/huggingface/transformers
138 | """
139 |
140 | NORMS_CORPUS_BASELINE = _coerce_if_not_none(
141 | os.environ.get('SCRUPLES_NORMS_CORPUS_BASELINE'),
142 | str)
143 | """The baseline to use for predicting the corpus in the norms demo.
144 |
145 | This constant should be one of the keys from
146 | ``scruples.baselines.corpus.FINE_TUNE_LM_BASELINES`` and should correspond to
147 | the ``NORMS_CORPUS_MODEL`` setting.
148 | """
149 |
150 | NORMS_CORPUS_MODEL = _coerce_if_not_none(
151 | os.environ.get('SCRUPLES_NORMS_CORPUS_MODEL'),
152 | str)
153 | """The path to the model directory to use for predicting the corpus.
154 |
155 | The path to the fine-tuned Dirichlet-multinomial likelihood model to use when
156 | predicting the corpus in the ``norms`` demo. The chosen directory should be the
157 | result of calling ``.save_pretrained`` on the model instance. See the
158 | transformers_ library for more details.
159 | """
160 |
161 | NORMS_PREDICT_BATCH_SIZE = _coerce_if_not_none(
162 | os.environ.get('SCRUPLES_NORMS_PREDICT_BATCH_SIZE'),
163 | int)
164 | """The batch size to use for predictions in the ``norms`` demo."""
165 |
166 | NORMS_GPU_IDS = _coerce_if_not_none(
167 | os.environ.get('SCRUPLES_NORMS_GPU_IDS'),
168 | str)
169 | """The GPU IDs to use for making predictions.
170 |
171 | The GPU IDs to use when making predictions in the ``norms`` demo. In the
172 | environment variable specifying this configuration, the GPU IDs should be
173 | separated by commas (i.e., ``"0,1,2"``).
174 | """
175 |
176 |
177 | # output and logging
178 |
179 | LOG_FORMAT = '%(asctime)s:%(levelname)s:%(name)s: %(message)s'
180 | """The format string for logging."""
181 |
182 | TQDM_KWARGS = {
183 | 'ncols': 72,
184 | 'leave': False
185 | }
186 | """Key-word arguments for tqdm progress bars."""
187 |
--------------------------------------------------------------------------------
/src/scruples/vendor/__init__.py:
--------------------------------------------------------------------------------
1 | """Vendored code from various projects."""
2 |
3 | from . import skopt
4 |
--------------------------------------------------------------------------------
/src/scruples/vendor/skopt.py:
--------------------------------------------------------------------------------
1 | # New BSD License
2 | #
3 | # Copyright (c) 2016 - scikit-optimize developers.
4 | #
5 | # All rights reserved.
6 | #
7 | # Redistribution and use in source and binary forms, with or without
8 | # modification, are permitted provided that the following conditions are
9 | # met:
10 | #
11 | # a. Redistributions of source code must retain the above copyright
12 | # notice, this list of conditions and the following disclaimer.
13 | #
14 | # b. Redistributions in binary form must reproduce the above copyright
15 | # notice, this list of conditions and the following disclaimer in the
16 | # documentation and/or other materials provided with the distribution.
17 | #
18 | # c. Neither the name of the scikit-optimize developers nor the names of
19 | # its contributors may be used to endorse or promote products derived from
20 | # this software without specific prior written permission.
21 | #
22 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
23 | # IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
24 | # TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
25 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
26 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
27 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
28 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
29 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30 | # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
31 | # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 |
34 | """Code vendored from skopt (scikit-optimize).
35 |
36 | See https://github.com/scikit-optimize/scikit-optimize for the source.
37 | """
38 |
39 | from skopt.utils import dump
40 |
41 |
42 | # vendored from commit: 6740876a6f9ad92c732d394e8534a5236a8d3f84
class CheckpointSaver(object):
    """
    Save current state after each iteration with `skopt.dump`.


    Example usage:
        import skopt

        checkpoint_callback = skopt.callbacks.CheckpointSaver("./result.pkl")
        skopt.gp_minimize(obj_fun, dims, callback=[checkpoint_callback])

    Parameters
    ----------
    * `checkpoint_path`: location where checkpoint will be saved to;
    * `dump_options`: options to pass on to `skopt.dump`, like `compress=9`
    """
    def __init__(self, checkpoint_path, **dump_options):
        # NOTE: this class is vendored from skopt (see the commit hash
        # noted above) -- keep the code in sync with upstream rather
        # than modifying it locally.
        self.checkpoint_path = checkpoint_path
        self.dump_options = dump_options

    def __call__(self, res):
        """
        Parameters
        ----------
        * `res` [`OptimizeResult`, scipy object]:
            The optimization as a OptimizeResult object.
        """
        # overwrite the checkpoint file on every callback invocation so
        # it always holds the most recent optimization state
        dump(res, self.checkpoint_path, **self.dump_options)
71 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the scruples repository."""
2 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Configuration for running tests."""
2 |
3 | import pytest
4 |
5 |
def pytest_addoption(parser):
    """Register the ``--skip-slow`` command-line flag with pytest."""
    skip_slow_kwargs = {
        "action": "store_true",
        "help": "Skip slow tests, useful for quick checks."
    }
    parser.addoption("--skip-slow", **skip_slow_kwargs)
11 |
12 |
def pytest_collection_modifyitems(config, items):
    """Apply a skip marker to ``slow``-marked tests when ``--skip-slow`` is set."""
    if not config.getoption("--skip-slow"):
        return

    skip_slow = pytest.mark.skip(
        reason="--skip-slow option is turned on.")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
20 |
--------------------------------------------------------------------------------
/tests/fixtures/corpus-easy/dev.scruples-corpus.jsonl:
--------------------------------------------------------------------------------
1 | {"id": "id_20", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Label this post nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"}
2 | {"id": "id_21", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Label this post yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "AUTHOR"}
3 | {"id": "id_22", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Label this post esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"}
4 | {"id": "id_23", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Label this post nah.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"}
5 | {"id": "id_24", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Label this post info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"}
6 |
--------------------------------------------------------------------------------
/tests/fixtures/corpus-easy/test.scruples-corpus.jsonl:
--------------------------------------------------------------------------------
1 | {"id": "id_25", "post_type": "HISTORICAL", "title": "AITA test post", "text": "The label for this post should be nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"}
2 | {"id": "id_26", "post_type": "HISTORICAL", "title": "AITA test post", "text": "The label for this post should be yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "AUTHOR"}
3 | {"id": "id_27", "post_type": "HISTORICAL", "title": "AITA test post", "text": "The label for this post should be esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"}
4 | {"id": "id_28", "post_type": "HISTORICAL", "title": "AITA test post", "text": "The label for this post should be nah.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"}
5 | {"id": "id_29", "post_type": "HISTORICAL", "title": "AITA test post", "text": "The label for this post should be info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"}
6 |
--------------------------------------------------------------------------------
/tests/fixtures/corpus-easy/train.scruples-corpus.jsonl:
--------------------------------------------------------------------------------
1 | {"id": "id_0", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post is nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"}
2 | {"id": "id_1", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post is yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "AUTHOR"}
3 | {"id": "id_2", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post is esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"}
4 | {"id": "id_3", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post is nah.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"}
5 | {"id": "id_4", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post is info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"}
6 | {"id": "id_5", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Another post that is nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"}
7 | {"id": "id_6", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Another post that is yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "AUTHOR"}
8 | {"id": "id_7", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Another post that is esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"}
9 | {"id": "id_8", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Another post that is nah.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"}
10 | {"id": "id_9", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Another post that is info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"}
11 | {"id": "id_10", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post's label is nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"}
12 | {"id": "id_11", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post's label is yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "AUTHOR"}
13 | {"id": "id_12", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post's label is esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"}
14 | {"id": "id_13", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post's label is nah.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"}
15 | {"id": "id_14", "post_type": "HISTORICAL", "title": "AITA test post", "text": "This post's label is info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"}
16 | {"id": "id_15", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Predict nta.", "label_scores": {"AUTHOR": 0, "OTHER": 10, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "OTHER"}
17 | {"id": "id_16", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Predict yta.", "label_scores": {"AUTHOR": 10, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 0}, "label": "AUTHOR"}
18 | {"id": "id_17", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Predict esh.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 10, "NOBODY": 0, "INFO": 0}, "label": "EVERYBODY"}
19 | {"id": "id_18", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Predict nah.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 10, "INFO": 0}, "label": "NOBODY"}
20 | {"id": "id_19", "post_type": "HISTORICAL", "title": "AITA test post", "text": "Predict info.", "label_scores": {"AUTHOR": 0, "OTHER": 0, "EVERYBODY": 0, "NOBODY": 0, "INFO": 10}, "label": "INFO"}
21 |
--------------------------------------------------------------------------------
/tests/fixtures/raw-data/test-comment.jsonl:
--------------------------------------------------------------------------------
1 | {"id":"aaaaaaa","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1540000000,"author":"example-user","body":"This comment is for testing.","score":10,"controversiality":0,"gilded":0}
2 |
--------------------------------------------------------------------------------
/tests/fixtures/raw-data/test-post-comments.jsonl:
--------------------------------------------------------------------------------
1 | {"id":"aaaaaaa","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1541000000,"author":"example-user","body":"NTA, definitely nta.","score":6,"controversiality":0,"gilded":0}
2 | {"id":"aaaaaab","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1542000000,"author":"example-user","body":"I think YTA.","score":7,"controversiality":0,"gilded":0}
3 | {"id":"aaaaaac","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1543000000,"author":"example-user","body":"I'm gonna have to go with NTA on this one.","score":8,"controversiality":0,"gilded":0}
4 | {"id":"aaaaaad","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1544000000,"author":"example-user","body":"Very surprising!","score":9,"controversiality":0,"gilded":0}
5 | {"id":"aaaaaae","subreddit_id":"t5_aaaaaa","subreddit":"AmItheAsshole","link_id":"t3_aaaaaa","parent_id":"t3_aaaaaa","created_utc":1545000000,"author":"AutoModerator","body":"^^^^AUTOMOD ***This is a copy of the above post. It is a record of the post as originally written, in case the post is deleted or edited.***\n\nThe original text.\n\n*I am a bot, and this action was performed automatically. Please [contact the moderators of this subreddit](/message/compose/?to=/r/AmItheAsshole) if you have any questions or concerns.*","score":10,"controversiality":0,"gilded":0}
6 |
--------------------------------------------------------------------------------
/tests/fixtures/raw-data/test-post.jsonl:
--------------------------------------------------------------------------------
1 | {"id":"aaaaaa","subreddit_id":"t5_aaaaa","subreddit":"AmItheAsshole","permalink":"/r/AmItheAsshole/comments/aaaaaa/aita_test_post/","domain":"self.AmItheAsshole","url":"https://www.reddit.com/r/AmItheAsshole/comments/aaaaaa/aita_test_post/","created_utc":1540000000,"author":"example-user","title":"AITA test post","selftext":"This post is for testing.","thumbnail":"default","score":10,"num_comments":10,"gilded":0,"retrieved_on":1545000000,"archived":false,"is_self":true,"over_18":false,"stickied":false}
2 |
--------------------------------------------------------------------------------
/tests/fixtures/resource-easy/dev.scruples-actions.jsonl:
--------------------------------------------------------------------------------
1 | {"id": "id_16", "actions": [{"id": "action_id_16_0", "description": "The good action."}, {"id": "action_id_16_1", "description": "The bad action."}], "gold_label": 0, "gold_annotations": [1, 0]}
2 | {"id": "id_17", "actions": [{"id": "action_id_17_0", "description": "This is bad!"}, {"id": "action_id_17_1", "description": "This is good!"}], "gold_label": 1, "gold_annotations": [0, 1]}
3 | {"id": "id_18", "actions": [{"id": "action_id_18_0", "description": "This is good!"}, {"id": "action_id_18_1", "description": "The bad action."}], "gold_label": 0, "gold_annotations": [1, 0]}
4 | {"id": "id_19", "actions": [{"id": "action_id_19_0", "description": "This is bad!"}, {"id": "action_id_19_1", "description": "The good action."}], "gold_label": 1, "gold_annotations": [0, 1]}
5 |
--------------------------------------------------------------------------------
/tests/fixtures/resource-easy/test.scruples-actions.jsonl:
--------------------------------------------------------------------------------
1 | {"id": "id_20", "actions": [{"id": "action_id_20_0", "description": "Indeed, this describes a very good action."}, {"id": "action_id_20_1", "description": "Indeed, this describes a very bad action."}], "gold_label": 0, "gold_annotations": [1, 0]}
2 | {"id": "id_21", "actions": [{"id": "action_id_21_0", "description": "Always try to do bad."}, {"id": "action_id_21_1", "description": "Always try to do good."}], "gold_label": 1, "gold_annotations": [0, 1]}
3 | {"id": "id_22", "actions": [{"id": "action_id_22_0", "description": "Always try to do good."}, {"id": "action_id_22_1", "description": "Indeed, this describes a very bad action."}], "gold_label": 0, "gold_annotations": [1, 0]}
4 | {"id": "id_23", "actions": [{"id": "action_id_23_0", "description": "Always try to do bad."}, {"id": "action_id_23_1", "description": "Indeed, this describes a very good action."}], "gold_label": 1, "gold_annotations": [0, 1]}
5 |
--------------------------------------------------------------------------------
/tests/fixtures/resource-easy/train.scruples-actions.jsonl:
--------------------------------------------------------------------------------
1 | {"id": "id_0", "actions": [{"id": "action_id_0_0", "description": "A good action."}, {"id": "action_id_0_1", "description": "A bad action."}], "gold_label": 0, "gold_annotations": [1, 0]}
2 | {"id": "id_1", "actions": [{"id": "action_id_1_0", "description": "A very bad action."}, {"id": "action_id_1_1", "description": "A very good action."}], "gold_label": 1, "gold_annotations": [0, 1]}
3 | {"id": "id_2", "actions": [{"id": "action_id_2_0", "description": "This action is good."}, {"id": "action_id_2_1", "description": "This action is bad."}], "gold_label": 0, "gold_annotations": [1, 0]}
4 | {"id": "id_3", "actions": [{"id": "action_id_3_0", "description": "bad!"}, {"id": "action_id_3_1", "description": "good!"}], "gold_label": 1, "gold_annotations": [0, 1]}
5 | {"id": "id_4", "actions": [{"id": "action_id_4_0", "description": "Very good."}, {"id": "action_id_4_1", "description": "Very bad."}], "gold_label": 0, "gold_annotations": [1, 0]}
6 | {"id": "id_5", "actions": [{"id": "action_id_5_0", "description": "A most bad thing to do."}, {"id": "action_id_5_1", "description": "A most good thing to do."}], "gold_label": 1, "gold_annotations": [0, 1]}
7 | {"id": "id_6", "actions": [{"id": "action_id_6_0", "description": "Do choose this good thing."}, {"id": "action_id_6_1", "description": "Don't choose this bad thing"}], "gold_label": 0, "gold_annotations": [1, 0]}
8 | {"id": "id_7", "actions": [{"id": "action_id_7_0", "description": "This description is about a bad action."}, {"id": "action_id_7_1", "description": "This description is about a good action."}], "gold_label": 1, "gold_annotations": [0, 1]}
9 | {"id": "id_8", "actions": [{"id": "action_id_8_0", "description": "A good action."}, {"id": "action_id_8_1", "description": "A very bad action."}], "gold_label": 0, "gold_annotations": [1, 0]}
10 | {"id": "id_9", "actions": [{"id": "action_id_9_0", "description": "A bad action."}, {"id": "action_id_9_1", "description": "A very good action."}], "gold_label": 1, "gold_annotations": [0, 1]}
11 | {"id": "id_10", "actions": [{"id": "action_id_10_0", "description": "good!"}, {"id": "action_id_10_1", "description": "This action is bad."}], "gold_label": 0, "gold_annotations": [1, 0]}
12 | {"id": "id_11", "actions": [{"id": "action_id_11_0", "description": "bad!"}, {"id": "action_id_11_1", "description": "This action is good."}], "gold_label": 1, "gold_annotations": [0, 1]}
13 | {"id": "id_12", "actions": [{"id": "action_id_12_0", "description": "A most good thing to do."}, {"id": "action_id_12_1", "description": "Very bad."}], "gold_label": 0, "gold_annotations": [1, 0]}
14 | {"id": "id_13", "actions": [{"id": "action_id_13_0", "description": "A most bad thing to do."}, {"id": "action_id_13_1", "description": "Very good."}], "gold_label": 1, "gold_annotations": [0, 1]}
15 | {"id": "id_14", "actions": [{"id": "action_id_14_0", "description": "This description is about a good action."}, {"id": "action_id_14_1", "description": "Don't choose this bad thing."}], "gold_label": 0, "gold_annotations": [1, 0]}
16 | {"id": "id_15", "actions": [{"id": "action_id_15_0", "description": "This description is about a bad action."}, {"id": "action_id_15_1", "description": "Do choose this good thing."}], "gold_label": 1, "gold_annotations": [0, 1]}
17 |
--------------------------------------------------------------------------------
/tests/scruples/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the ``scruples`` package."""
2 |
--------------------------------------------------------------------------------
/tests/scruples/analysis/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``scruples.analysis``."""
2 |
--------------------------------------------------------------------------------
/tests/scruples/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``scruples.baselines``."""
2 |
--------------------------------------------------------------------------------
/tests/scruples/baselines/test_labels.py:
--------------------------------------------------------------------------------
1 | """Tests for scruples.baselines.labels."""
2 |
3 | import unittest
4 |
5 | import numpy as np
6 |
7 | from scruples.baselines import labels
8 |
9 |
class PriorBaselineTestCase(unittest.TestCase):
    """Test scruples.baselines.labels.PriorBaseline."""

    # toy dataset: label 'a' has prior 2/3 and label 'b' has prior 1/3
    FEATURES = [
        [0, 1],
        [0, 1],
        [1, 0]
    ]
    LABELS = ['a', 'a', 'b']

    def test_predicts_most_frequent_label(self):
        # NOTE(review): the baseline is used without instantiation --
        # presumably PriorBaseline is directly usable as-is; confirm in
        # scruples.baselines.labels.
        baseline = labels.PriorBaseline
        baseline.fit(X=self.FEATURES, y=self.LABELS)
        predictions = baseline.predict(self.FEATURES)

        self.assertEqual(predictions.tolist(), ['a', 'a', 'a'])

    def test_predicts_probabilities_by_class_prior(self):
        baseline = labels.PriorBaseline
        baseline.fit(X=self.FEATURES, y=self.LABELS)
        probabilities = baseline.predict_proba(self.FEATURES)

        # compare the float probabilities elementwise: calling
        # assertAlmostEqual on whole nested lists only passes via exact
        # equality and raises a TypeError (not an informative failure)
        # whenever the values differ slightly
        expected = [
            [2./3., 1./3.],
            [2./3., 1./3.],
            [2./3., 1./3.]
        ]
        for row, expected_row in zip(probabilities.tolist(), expected):
            for probability, expected_probability in zip(row, expected_row):
                self.assertAlmostEqual(probability, expected_probability)
39 |
40 |
class StratifiedBaselineTestCase(unittest.TestCase):
    """Test scruples.baselines.labels.StratifiedBaseline."""

    # a toy training set whose labels have priors of 2/3 ('a') and
    # 1/3 ('b')
    FEATURES = [
        [0, 1],
        [0, 1],
        [1, 0]
    ]
    LABELS = ['a', 'a', 'b']

    # the number of trials to perform in statistically testing that the
    # predictions adhere to the label frequencies from the training data
    N_TRIALS = 10000

    def test_predicts_random_label_by_class_probability(self):
        baseline = labels.StratifiedBaseline
        baseline.fit(X=self.FEATURES, y=self.LABELS)
        predictions = baseline.predict(
            [[0, 0] for _ in range(self.N_TRIALS)])

        elements, counts = np.unique(predictions, return_counts=True)

        # every training label should appear among the predictions
        # (np.unique returns the elements in sorted order)
        self.assertEqual(elements.tolist(), ['a', 'b'])

        # the empirical frequency of each label should fall within 5
        # standard errors of its training-set prior
        std_error = ((2./3. * 1./3.) / self.N_TRIALS)**0.5
        priors = {'a': 2./3., 'b': 1./3.}
        for label, count in zip(elements.tolist(), counts.tolist()):
            frequency = count / self.N_TRIALS
            self.assertGreater(frequency, priors[label] - 5 * std_error)
            self.assertLess(frequency, priors[label] + 5 * std_error)
74 |
--------------------------------------------------------------------------------
/tests/scruples/baselines/test_linear.py:
--------------------------------------------------------------------------------
1 | """Tests for scruples.baselines.linear."""
2 |
3 | import unittest
4 |
5 | from scruples.baselines import linear
6 | from .utils import BaselineTestMixin
7 |
8 |
class LogististicRegressionBaselineTestCase(
        BaselineTestMixin,
        unittest.TestCase
):
    """Test the logistic regression on bag-of-ngrams baseline."""

    # NOTE(review): "Logististic" misspells "Logistic"; the name is kept
    # as-is since renaming a test class could break external references.

    # the model under test
    BASELINE_MODEL = linear.LogisticRegressionBaseline
    # the hyper-parameter search space -- presumably consumed by
    # BaselineTestMixin's fitting logic; verify in tests utils
    BASELINE_HYPER_PARAMETERS = linear.LOGISTIC_REGRESSION_HYPER_PARAMETERS
    # which easy dataset fixture the mixin trains/evaluates on
    DATASET = 'corpus'
18 |
19 |
class LogististicRankerBaselineTestCase(
        BaselineTestMixin,
        unittest.TestCase
):
    """Test the logistic ranker on bag-of-ngrams baseline."""

    # NOTE(review): "Logististic" misspells "Logistic"; the name is kept
    # as-is since renaming a test class could break external references.

    # the model under test
    BASELINE_MODEL = linear.LogisticRankerBaseline
    # the hyper-parameter search space -- presumably consumed by
    # BaselineTestMixin's fitting logic; verify in tests utils
    BASELINE_HYPER_PARAMETERS = linear.LOGISTIC_RANKER_HYPER_PARAMETERS
    # which easy dataset fixture the mixin trains/evaluates on
    DATASET = 'resource'
29 |
--------------------------------------------------------------------------------
/tests/scruples/baselines/test_loss.py:
--------------------------------------------------------------------------------
1 | """Tests for scruples.baselines.loss."""
2 |
3 | import unittest
4 |
5 | import numpy as np
6 | import torch as th
7 |
8 | from scruples.baselines import loss
9 |
10 |
class TestSoftCrossEntropyLoss(unittest.TestCase):
    """Test SoftCrossEntropyLoss."""

    def test_forward(self):
        # the loss takes raw logits and soft (probability) targets
        f = loss.SoftCrossEntropyLoss()

        # uniform logits against a uniform soft target give a
        # cross-entropy of log(2) ~= 0.6931471805599453
        self.assertAlmostEqual(
            f(
                th.Tensor([[0., 0.]]),
                th.Tensor([[0.5, 0.5]])
            ).item(),
            0.6931471805599453)
        # a batch of two examples; softmax([-1.0986123, 1.0986123]) is
        # [0.1, 0.9], and the expected value is the mean of the two
        # per-example cross-entropies: (0.693147 + 1.423695) / 2
        self.assertAlmostEqual(
            f(
                th.Tensor([[0., 0.], [-1.0986123, 1.0986123]]),
                th.Tensor([[0.3, 0.7], [0.6, 0.4]])
            ).item(),
            1.0584212213097515)
29 |
30 |
class TestDirichletMultinomialLoss(unittest.TestCase):
    """Test DirichletMultinomialLoss."""

    def test_forward(self):
        # the loss takes *log* concentration parameters as its first
        # argument and observed label counts as its second -- presumably
        # returning the mean negative log-likelihood over the batch;
        # confirm in scruples.baselines.loss
        f = loss.DirichletMultinomialLoss()

        # when [a, b] goes to [0, 0], the dirichlet prior becomes a
        # Bernoulli distribution with p = 0.5, where there's a 50/50
        # chance that the label always comes up 0 or always comes up 1
        self.assertAlmostEqual(
            f(
                th.Tensor([[np.log(1e-20), np.log(1e-20)]]),
                th.Tensor([[3, 0]])
            ).item(),
            0.6931471805599453,
            # looser tolerance: 1e-20 only approximates the [0, 0] limit
            places=4)
        # test other cases
        self.assertAlmostEqual(
            f(
                th.Tensor([[np.log(1.), np.log(1.)]]),
                th.Tensor([[0.5, 0.5]])
            ).item(),
            0.9347116558304358)
        self.assertAlmostEqual(
            f(
                th.Tensor(
                    [[np.log(1.), np.log(1.)],
                     [np.log(0.5), np.log(2.)]]
                ),
                th.Tensor([[0.3, 0.7], [0.6, 0.4]])
            ).item(),
            1.1093992405423625)
63 |
--------------------------------------------------------------------------------
/tests/scruples/baselines/test_metrics.py:
--------------------------------------------------------------------------------
1 | """Tests for scruples.baselines.metrics."""
2 |
--------------------------------------------------------------------------------
/tests/scruples/baselines/test_naivebayes.py:
--------------------------------------------------------------------------------
1 | """Tests for scruples.baselines.naivebayes."""
2 |
3 | import unittest
4 |
5 | from scruples.baselines import naivebayes
6 | from .utils import BaselineTestMixin
7 |
8 |
class BernoulliNBBaselineTestCase(
        BaselineTestMixin,
        unittest.TestCase
):
    """Test the bernoulli naive bayes on bag-of-ngrams baseline."""

    # the model under test
    BASELINE_MODEL = naivebayes.BernoulliNBBaseline
    # the hyper-parameter search space -- presumably consumed by
    # BaselineTestMixin's fitting logic; verify in tests utils
    BASELINE_HYPER_PARAMETERS = naivebayes.BERNOULLINB_HYPER_PARAMETERS
    # which easy dataset fixture the mixin trains/evaluates on
    DATASET = 'corpus'
18 |
19 |
class MultinomialNBBaselineTestCase(
        BaselineTestMixin,
        unittest.TestCase
):
    """Test the multinomial naive bayes on bag-of-ngrams baseline."""

    # the model under test
    BASELINE_MODEL = naivebayes.MultinomialNBBaseline
    # the hyper-parameter search space -- presumably consumed by
    # BaselineTestMixin's fitting logic; verify in tests utils
    BASELINE_HYPER_PARAMETERS = naivebayes.MULTINOMIALNB_HYPER_PARAMETERS
    # which easy dataset fixture the mixin trains/evaluates on
    DATASET = 'corpus'
29 |
30 |
class ComplementNBBaselineTestCase(
        BaselineTestMixin,
        unittest.TestCase
):
    """Test the complement naive bayes on bag-of-ngrams baseline."""

    # the model under test
    BASELINE_MODEL = naivebayes.ComplementNBBaseline
    # the hyper-parameter search space -- presumably consumed by
    # BaselineTestMixin's fitting logic; verify in tests utils
    BASELINE_HYPER_PARAMETERS = naivebayes.COMPLEMENTNB_HYPER_PARAMETERS
    # which easy dataset fixture the mixin trains/evaluates on
    DATASET = 'corpus'
40 |
--------------------------------------------------------------------------------
/tests/scruples/baselines/test_utils.py:
--------------------------------------------------------------------------------
1 | """Tests for scruples.baselines.utils."""
2 |
3 | import unittest
4 | from unittest.mock import Mock
5 |
6 | import numpy as np
7 | import pandas as pd
8 |
9 | from scruples.baselines import utils
10 |
11 |
class ConcatTitleAndTextTestCase(unittest.TestCase):
    """Test concat_title_and_text."""

    def test_concatenates_title_and_text_columns(self):
        # build a two-row frame with 'title' and 'text' columns
        titles = ['The Title A', 'The Title B']
        texts = ['The text A.', 'The text B.']
        features = pd.DataFrame({'title': titles, 'text': texts})

        # each row's title and text should be joined with a newline
        expected = [
            'The Title A\nThe text A.',
            'The Title B\nThe text B.'
        ]
        self.assertEqual(
            utils.concat_title_and_text(features).tolist(),
            expected)
24 |
25 |
class DirichletMultinomialTestCase(unittest.TestCase):
    """Test dirichlet_multinomial."""

    def test_dirichlet_multinomial(self):
        # pairs of (log concentration parameters, expected mean
        # probabilities); each expected row is the normalized
        # concentrations for that example
        cases = [
            (
                [[np.log(0.1), np.log(0.1)]],
                [[0.5, 0.5]]
            ),
            (
                [[np.log(1.), np.log(1.)],
                 [np.log(1.), np.log(4.)]],
                [[0.5, 0.5], [0.2, 0.8]]
            ),
            (
                [[np.log(1.), np.log(1.), np.log(1.)]],
                [[1/3., 1/3., 1/3.]]
            ),
            (
                [[np.log(1.), np.log(2.), np.log(1.)],
                 [np.log(2.), np.log(3.), np.log(5.)]],
                [[0.25, 0.5, 0.25], [0.2, 0.3, 0.5]]
            )
        ]
        for log_alphas, expected in cases:
            self.assertTrue(np.allclose(
                utils.dirichlet_multinomial(log_alphas).tolist(),
                expected))
52 |
53 |
class ResourceTransformerTestCase(unittest.TestCase):
    """Test ResourceTransformer."""

    def test_set_params(self):
        """``set_params`` reaches both the wrapper and the wrapped object."""
        # mock out the inputs
        transformer = Mock()
        transformer.get_params.return_value = {'x': None}
        second_transformer = Mock()
        second_transformer.get_params.return_value = {'x': None}
        third_transformer = Mock()
        third_transformer.get_params.return_value = {'x': None}

        resource_transformer = utils.ResourceTransformer(
            transformer=transformer)

        # test setting params on the transformer attribute
        resource_transformer.set_params(transformer__x=1)

        transformer.set_params.assert_called()
        transformer.set_params.assert_called_with(x=1)

        # test setting params on the ResourceTransformer
        self.assertNotEqual(
            resource_transformer.transformer,
            second_transformer)

        resource_transformer.set_params(
            transformer=second_transformer)

        self.assertEqual(
            resource_transformer.transformer,
            second_transformer)

        # test setting params on the ResourceTransformer and the new
        # transformer at the same time
        self.assertNotEqual(
            resource_transformer.transformer,
            third_transformer)

        resource_transformer.set_params(
            transformer=third_transformer,
            transformer__x='foo')

        self.assertEqual(
            resource_transformer.transformer,
            third_transformer)

        # N.B. this line previously read ``asset_called`` (a typo); on a
        # Mock a misspelled assert method is silently auto-created and
        # does nothing, so the check never ran.
        third_transformer.set_params.assert_called()
        third_transformer.set_params.assert_called_with(x='foo')

    def test_fit(self):
        """``fit`` passes the concatenated action columns to the transformer."""
        # create the data
        X = pd.DataFrame([
            {'action0': 1, 'action1': 4},
            {'action0': 2, 'action1': 5},
            {'action0': 3, 'action1': 6},
        ])

        # mock arguments to instantiate ResourceTransformer
        transformer = Mock()

        # create the ResourceTransformer instance
        resource_transformer = utils.ResourceTransformer(
            transformer=transformer)

        # fit should return self for chaining
        self.assertIsInstance(
            resource_transformer.fit(X),
            utils.ResourceTransformer)

        # the wrapped transformer must be fit on both action columns
        # stacked into one series
        transformer.fit.assert_called()
        self.assertEqual(
            transformer.fit.call_args[0][0].tolist(),
            pd.concat([X['action0'], X['action1']]).tolist())

    def test_transform(self):
        """``transform`` produces the expected output on a small frame."""
        # create the data
        X = pd.DataFrame([
            {'action0': 1, 'action1': 4},
            {'action0': 2, 'action1': 4},
            {'action0': 3, 'action1': 4}
        ])

        X_transformed = [6, 4, 2]

        class TimesTwoTransformer(object):
            def fit(self, X, y = None):
                return self

            def transform(self, X):
                return 2 * X

        # N.B. the previous version also defined an unused ``resources``
        # list here; it has been removed.
        transformer = TimesTwoTransformer()

        resource_transformer = utils.ResourceTransformer(
            transformer=transformer)

        resource_transformer.fit(X)

        self.assertEqual(
            resource_transformer.transform(X).tolist(),
            X_transformed)
157 |
--------------------------------------------------------------------------------
/tests/scruples/baselines/utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for testing baselines."""
2 |
3 | import os
4 | import tempfile
5 |
6 | import pandas as pd
7 | import pytest
8 | from sklearn import metrics
9 | from skopt import BayesSearchCV
10 |
11 | from scruples import settings as scruples_settings
12 | from scruples.dataset import readers
13 | from ... import settings, utils
14 |
15 |
16 | # classes
17 |
class BaselineTestMixin:
    """Test a baseline model by verifying it solves an easy dataset.

    Attributes
    ----------
    BASELINE_MODEL : sklearn.base.BaseEstimator
        The baseline model to test.
    BASELINE_HYPER_PARAMETERS : Dict
        The dictionary defining the hyper-parameter search space for the
        baseline model.
    DATASET : str
        The dataset against which the baseline should be run. Must be
        either ``"resource"`` or ``"corpus"``.

    Examples
    --------
    To create a test case for a baseline model, inherit from this class
    along with ``unittest.TestCase`` and provide the ``BASELINE_MODEL``,
    ``BASELINE_HYPER_PARAMETERS``, and ``DATASET`` class attributes::

        class LogisticRegressionBaselineTestCase(
            BaselineTestMixin,
            unittest.TestCase
        ):
            '''Test the logistic regression baseline.'''

            BASELINE_MODEL = LogisticRegressionBaseline
            BASELINE_HYPER_PARAMETERS = LOGISTIC_REGRESSION_HYPER_PARAMS
            DATASET = 'corpus'

    """

    BASELINE_MODEL = None
    BASELINE_HYPER_PARAMETERS = None
    DATASET = None

    def setUp(self):
        """Validate the subclass, then copy the dataset fixture to disk."""
        super().setUp()

        # validate the class

        if self.BASELINE_MODEL is None:
            raise ValueError(
                'Subclasses of BaselineTestMixin must provide a'
                ' BASELINE_MODEL class attribute.')

        if self.BASELINE_HYPER_PARAMETERS is None:
            raise ValueError(
                'Subclasses of BaselineTestMixin must provide a'
                ' BASELINE_HYPER_PARAMETERS class attribute.')

        if self.DATASET is None:
            raise ValueError(
                'Subclasses of BaselineTestMixin must provide a DATASET'
                ' class attribute.')

        if self.DATASET not in ['resource', 'corpus']:
            raise ValueError(
                'The DATASET class attribute must either be'
                ' "resource", or "corpus".')

        # copy the dataset fixture from the package to disk

        if self.DATASET == 'resource':
            Reader = readers.ScruplesResource
            fixture_path = settings.RESOURCE_EASY_DIR
            split_filename_template = (
                scruples_settings.RESOURCE_FILENAME_TEMPLATE)
        elif self.DATASET == 'corpus':
            Reader = readers.ScruplesCorpus
            fixture_path = settings.CORPUS_EASY_DIR
            split_filename_template = (
                scruples_settings.CORPUS_FILENAME_TEMPLATE)

        self.temp_dir = tempfile.TemporaryDirectory()

        for split in Reader.SPLITS:
            split_filename = split_filename_template.format(split=split)
            utils.copy_pkg_resource_to_disk(
                pkg='tests',
                src=os.path.join(fixture_path, split_filename),
                dst=os.path.join(self.temp_dir.name, split_filename))

        # load the dataset

        self.dataset = Reader(data_dir=self.temp_dir.name)

    def tearDown(self):
        """Remove the temporary directory holding the dataset fixture."""
        # guard with hasattr so tearDown is safe even if setUp failed
        # before creating the temporary directory
        if hasattr(self, 'temp_dir'):
            self.temp_dir.cleanup()
        # previously super().tearDown() was not called, which breaks
        # cooperative multiple inheritance for mixins (setUp chains up,
        # so tearDown must as well)
        super().tearDown()

    @pytest.mark.slow
    def test_it_solves_scruples_easy_when_untuned(self):
        """The untuned baseline should reach 100% accuracy on dev."""
        baseline = self.BASELINE_MODEL

        # train the model
        _, train_features, train_labels, train_label_scores =\
            self.dataset.train
        baseline.fit(train_features, train_labels)

        # predict with the model on dev
        _, dev_features, dev_labels, dev_label_scores =\
            self.dataset.dev
        predictions = baseline.predict(dev_features)

        # check that the accuracy is 100%
        self.assertEqual(
            metrics.accuracy_score(
                y_true=dev_labels,
                y_pred=predictions),
            1.)

    @pytest.mark.slow
    def test_it_solves_scruples_easy_when_tuned(self):
        """Hyper-parameter tuning should also reach 100% accuracy on dev."""
        baseline = BayesSearchCV(
            self.BASELINE_MODEL,
            self.BASELINE_HYPER_PARAMETERS,
            n_iter=16,
            n_points=2,
            cv=4,
            n_jobs=1,
            refit=True)

        # train the model, tuning hyper-parameters
        _, train_features, train_labels, train_label_scores =\
            self.dataset.train
        baseline.fit(train_features, train_labels)

        # predict with the model on dev
        _, dev_features, dev_labels, dev_label_scores =\
            self.dataset.dev
        predictions = baseline.predict(dev_features)

        # check that the accuracy is 100%
        self.assertEqual(
            metrics.accuracy_score(
                y_true=dev_labels,
                y_pred=predictions),
            1.)
156 |
--------------------------------------------------------------------------------
/tests/scruples/data/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``scruples.data``."""
2 |
--------------------------------------------------------------------------------
/tests/scruples/data/test_utils.py:
--------------------------------------------------------------------------------
1 | """Tests for scruples.data.utils."""
2 |
3 | import unittest
4 |
5 | import attr
6 |
7 | from scruples.data import utils
8 |
9 |
class InstantiateAttrsWithExtraKwargsTestCase(unittest.TestCase):
    """Test instantiate_attrs_with_extra_kwargs."""

    class NonAttrsClass:
        def __init__(self, foo):
            self.foo = foo

    @attr.s
    class AttrsClass:
        foo = attr.ib()

    def test_raises_error_on_non_attrs_classes(self):
        # a class without the @attr.s decorator should be rejected
        error_pattern = r'.* must be decorated with @attr\.s'
        with self.assertRaisesRegex(ValueError, error_pattern):
            utils.instantiate_attrs_with_extra_kwargs(
                self.NonAttrsClass, foo=1)

    def test_instantiates_class_with_no_extra_kwargs(self):
        created = utils.instantiate_attrs_with_extra_kwargs(
            self.AttrsClass, foo=1)

        self.assertEqual(created.foo, 1)

    def test_instantiates_class_with_extra_kwargs(self):
        # keyword arguments the class does not declare are tolerated
        created = utils.instantiate_attrs_with_extra_kwargs(
            self.AttrsClass, foo='a', bar='b')

        self.assertEqual(created.foo, 'a')
40 |
41 |
class CachedPropertyTestCase(unittest.TestCase):
    """Test cached_property."""

    def test_makes_method_into_property(self):
        class Foo:
            @utils.cached_property
            def bar(self):
                return 1

        self.assertEqual(Foo().bar, 1)

    def test_caches_property_from_method(self):
        class Foo:
            @utils.cached_property
            def bar(self):
                # track how often this method actually runs; when the
                # property is cached it should run exactly once
                self.num_calls = 1 + getattr(self, 'num_calls', 0)

                return self.num_calls

        foo = Foo()

        # the first access must return 1
        self.assertEqual(foo.bar, 1)
        # a second access returns 1 again only if the value was cached
        # rather than recomputed
        self.assertEqual(foo.bar, 1)
72 |
--------------------------------------------------------------------------------
/tests/scruples/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``scruples.dataset``."""
2 |
--------------------------------------------------------------------------------
/tests/scruples/demos/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``scruples.demos``."""
2 |
--------------------------------------------------------------------------------
/tests/scruples/demos/norms/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``scruples.demos.norms``."""
2 |
--------------------------------------------------------------------------------
/tests/scruples/demos/norms/test_utils.py:
--------------------------------------------------------------------------------
1 | """Tests for scruples.demos.norms.utils."""
2 |
3 | import unittest
4 | from unittest import mock
5 |
6 | from scruples.demos.norms import utils
7 |
8 |
class PredictionDatasetTestCase(unittest.TestCase):
    """Test scruples.demos.norms.utils.PredictionDataset."""

    def setUp(self):
        self.features = [[1, 2], [1, 0]]
        self.transform = lambda x: x

        self.dataset = utils.PredictionDataset(
            features=self.features,
            transform=self.transform)

    def test___init__(self):
        # the constructor should store its arguments unchanged
        self.assertEqual(self.dataset.features, self.features)
        self.assertEqual(self.dataset.transform, self.transform)

    def test___len__(self):
        # the dataset's length matches the number of feature rows
        self.assertEqual(len(self.dataset), len(self.features))

    def test___get_item__(self):
        # indexing the dataset returns the corresponding feature row
        for idx, expected in enumerate(self.features):
            self.assertEqual(self.dataset[idx], expected)

    def test_it_applies_transform(self):
        mock_transform = mock.MagicMock(return_value='foo')

        dataset = utils.PredictionDataset(
            features=self.features,
            transform=mock_transform)

        # the transform is applied to the indexed row and its result
        # is what __getitem__ returns
        self.assertEqual(dataset[1], 'foo')
        mock_transform.assert_called_with(self.features[1])
41 |
--------------------------------------------------------------------------------
/tests/scruples/demos/scoracle/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``scruples.demos.scoracle``."""
2 |
--------------------------------------------------------------------------------
/tests/scruples/extraction/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``scruples.extraction``."""
2 |
--------------------------------------------------------------------------------
/tests/scruples/extraction/test_base.py:
--------------------------------------------------------------------------------
1 | """Tests for scruples.extraction.base."""
2 |
3 | import unittest
4 |
5 | from scruples.extraction import base
6 |
7 |
class LoggedCallableTestCase(unittest.TestCase):
    """Test scruples.extraction.base.LoggedCallable."""

    # subclasses used to exercise LoggedCallable

    class Add1(base.LoggedCallable):
        def apply(self, y):
            return y + 1

    class AddX(base.LoggedCallable):
        def __init__(self, x, *args, **kwargs):
            super().__init__(*args, **kwargs)

            self.x = x

        def apply(self, y):
            return self.x + y

    # helpers

    @staticmethod
    def _expected_log(arg_output_pairs):
        # each log entry pairs the call's inputs -- (args, kwargs) --
        # with its output
        return [
            (((arg,), {}), output)
            for arg, output in arg_output_pairs
        ]

    # tests

    def test_apply_raises_not_implemented(self):
        logged_callable = base.LoggedCallable()

        with self.assertRaises(NotImplementedError):
            logged_callable.apply()

        with self.assertRaises(NotImplementedError):
            logged_callable()

    def test_subclass_without_args(self):
        add1 = self.Add1()
        arg_output_pairs = [(0, 1), (1, 2), (5, 6)]

        for arg, output in arg_output_pairs:
            self.assertEqual(add1(arg), output)

        # every call above should have been recorded, in order
        self.assertEqual(
            add1.call_log,
            self._expected_log(arg_output_pairs))

    def test_subclass_with_args(self):
        add5 = self.AddX(x=5)
        arg_output_pairs = [(0, 5), (1, 6), (5, 10)]

        for arg, output in arg_output_pairs:
            self.assertEqual(add5(arg), output)

        # every call above should have been recorded, in order
        self.assertEqual(
            add5.call_log,
            self._expected_log(arg_output_pairs))

    def test_when_log_calls_is_false(self):
        # with logging disabled, calls still work but no log is kept

        add1 = self.Add1(log_calls=False)

        for arg, output in [(0, 1), (1, 2), (5, 6)]:
            self.assertEqual(add1(arg), output)

        self.assertEqual(add1.call_log, None)

        add5 = self.AddX(x=5, log_calls=False)

        for arg, output in [(0, 5), (1, 6), (5, 10)]:
            self.assertEqual(add5(arg), output)

        self.assertEqual(add5.call_log, None)
143 |
144 |
class CaseTestCase(unittest.TestCase):
    """Test scruples.extraction.base.Case."""

    class SomeToNoneCase(base.Case):
        def match(self, x):
            return (x, x is not None)

        def transform(self, x):
            return None

        def filter(self, x):
            return False

    # tests

    def test_subclass_when_case_matches(self):
        # anything that is not None matches and is transformed to None
        for matching_input in [1, 'a', '']:
            self.assertEqual(
                self.SomeToNoneCase()(matching_input),
                (None, True))

    def test_subclass_when_case_does_not_match(self):
        # None is the only non-matching input for this case
        self.assertEqual(
            self.SomeToNoneCase()(None),
            (None, False))
175 |
--------------------------------------------------------------------------------
/tests/scruples/extraction/test_transformers.py:
--------------------------------------------------------------------------------
1 | """Tests for scruples.extraction.transformers."""
2 |
3 | import unittest
4 |
5 | from scruples.extraction import transformers
6 |
7 |
class GerundifyingTransformerTestCase(unittest.TestCase):
    """Test scruples.extraction.transformers.GerundifyingTransformer."""

    def _assert_gerundified(self, input_expected_pairs):
        # apply a fresh transformer to each input and check the result
        transform = transformers.GerundifyingTransformer()
        for text, expected in input_expected_pairs:
            self.assertEqual(transform(text), expected)

    def test_leaves_gerund_phrases_alone(self):
        self._assert_gerundified([
            ('Running down the street',
             'running down the street'),
            ('Running down the street and waving goodbye to my friend.',
             'running down the street and waving goodbye to my friend.'),
        ])

    def test_transforms_infinitives(self):
        self._assert_gerundified([
            # simple case
            ('to run through a stop light',
             'running through a stop light'),
            ('to wave to a friend.',
             'waving to a friend.'),
            # coordinated verbs
            ('to run through and totally ignore a stop light',
             'running through and totally ignoring a stop light'),
            ('to wave to a friend and say goodbye.',
             'waving to a friend and saying goodbye.'),
        ])

    def test_transforms_I_phrases(self):
        self._assert_gerundified([
            # simple case
            ('I ran through a stop light',
             'running through a stop light'),
            ('I waved to a friend.',
             'waving to a friend.'),
            # coordinated verbs
            ('I ran through and totally ignored a stop light',
             'running through and totally ignoring a stop light'),
            ('I waved to a friend and said goodbye.',
             'waving to a friend and saying goodbye.'),
        ])

    def test_drops_auxiliary_verbs(self):
        # will
        self._assert_gerundified([
            # simple case
            ('I will run through a stop light',
             'running through a stop light'),
            ('I won\'t run through a stop light',
             'not running through a stop light'),
            ('I will wave to a friend.',
             'waving to a friend.'),
            ('I won\'t wave to a friend.',
             'not waving to a friend.'),
            # coordinated verbs
            ('I will run through and totally ignore a stop light',
             'running through and totally ignoring a stop light'),
            ('I won\'t run through and totally ignore a stop light',
             'not running through and totally ignoring a stop light'),
            ('I will wave to a friend and say goodbye.',
             'waving to a friend and saying goodbye.'),
            ('I won\'t wave to a friend and say goodbye.',
             'not waving to a friend and saying goodbye.'),
        ])
        # do
        self._assert_gerundified([
            # simple case
            ('I did run through a stop light',
             'running through a stop light'),
            ('I didn\'t run through a stop light',
             'not running through a stop light'),
            ('I did wave to a friend.',
             'waving to a friend.'),
            ('I didn\'t wave to a friend.',
             'not waving to a friend.'),
            # coordinated verbs
            ('I did run through and totally ignore a stop light',
             'running through and totally ignoring a stop light'),
            ('I didn\'t run through and totally ignore a stop light',
             'not running through and totally ignoring a stop light'),
            ('I did wave to a friend and say goodbye.',
             'waving to a friend and saying goodbye.'),
            ('I didn\'t wave to a friend and say goodbye.',
             'not waving to a friend and saying goodbye.'),
        ])

    def test_handles_auxiliary_forms_of_to_be(self):
        self._assert_gerundified([
            # to be when it is _not_ auxiliary
            ("I'm mad at my friend",
             'being mad at my friend'),
            ("I'm happy all the time.",
             'being happy all the time.'),
            # to be when it is auxiliary
            ("I'm thinking about making some coffee",
             'thinking about making some coffee'),
            ("I'm running for president.",
             'running for president.'),
        ])

    def test_handles_Im(self):
        self._assert_gerundified([
            ('Im happy to see you',
             'being happy to see you'),
            ('Im sure it is.',
             'being sure it is.'),
        ])
142 |
--------------------------------------------------------------------------------
/tests/scruples/test_settings.py:
--------------------------------------------------------------------------------
1 | """Tests for scruples.settings."""
2 |
--------------------------------------------------------------------------------
/tests/settings.py:
--------------------------------------------------------------------------------
"""Settings for the tests package."""


# fixture locations
#
# All paths are resources inside the ``tests`` package, loaded with
# ``pkg_resources`` (see ``tests.utils.copy_pkg_resource_to_disk``).

# easy version dataset fixtures: small datasets that a working baseline
# should solve perfectly (used by BaselineTestMixin)
CORPUS_EASY_DIR = 'fixtures/corpus-easy'
RESOURCE_EASY_DIR = 'fixtures/resource-easy'

# raw data fixtures
TEST_COMMENT_PATH = 'fixtures/raw-data/test-comment.jsonl'
TEST_POST_PATH = 'fixtures/raw-data/test-post.jsonl'
TEST_POST_COMMENTS_PATH = 'fixtures/raw-data/test-post-comments.jsonl'

# gold extraction examples
GOLD_TITLE_DESCRIPTION_EXTRACTIONS_PATH =\
    'fixtures/gold-extractions/gold-title-description-extractions.jsonl'
18 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for writing tests."""
2 |
import shutil

import pkg_resources
4 |
5 |
def copy_pkg_resource_to_disk(
        pkg: str,
        src: str,
        dst: str
) -> None:
    """Copy a package resource to disk.

    Copy the resource from the package ``pkg`` at the path ``src`` to
    disk at the path ``dst``.

    Parameters
    ----------
    pkg : str
        The package holding the resource.
    src : str
        The source path for the resource in ``pkg``.
    dst : str
        The destination path for the resource on disk.

    Notes
    -----
    This function is primarily useful for testing code that requires
    resources to be written on disk, when those test fixtures are
    shipped in the package.
    """
    with pkg_resources.resource_stream(pkg, src) as src_file,\
            open(dst, 'wb') as dst_file:
        # stream the resource in chunks instead of buffering the whole
        # file in memory (previously src_file.read() loaded it all)
        shutil.copyfileobj(src_file, dst_file)
34 |
--------------------------------------------------------------------------------