├── spam ├── __init__.py ├── .gitignore ├── .notebooks ├── requirements.txt ├── README.md ├── download_data.sh ├── utils.py └── 02_spam_data_augmentation_tutorial.py ├── drybell ├── __init__.py ├── .scripts ├── requirements.txt ├── data │ ├── raw_data.parquet │ └── celebrity_knowledge_base.txt ├── drybell_lfs_spark.py ├── drybell_lfs.py ├── drybell_dask.py ├── drybell_spark.py └── README.md ├── crowdsourcing ├── __init__.py ├── .notebooks ├── requirements.txt ├── README.md ├── download-data.sh ├── data.py └── crowdsourcing_tutorial.py ├── getting_started ├── __init__.py ├── .notebooks ├── .gitignore ├── requirements.txt ├── img │ └── Overview.png ├── README.md ├── download_data.sh ├── utils.py └── getting_started.py ├── visual_relation ├── __init__.py ├── .notebooks ├── requirements.txt ├── README.md ├── download_full_data.sh ├── download_sample_data.sh ├── utils.py ├── model.py └── visual_relation_tutorial.py ├── spouse ├── .notebooks ├── requirements.txt ├── .gitignore ├── imgs │ └── sentence.jpg ├── README.md ├── download_data.sh ├── utils.py ├── preprocessors.py ├── tf_model.py └── spouse_demo.py ├── recsys ├── .notebooks ├── requirements.txt ├── README.md ├── utils.py └── recsys_tutorial.py ├── multitask ├── .notebooks ├── requirements.txt ├── README.md └── utils.py ├── .gitattributes ├── .dockerignore ├── requirements.txt ├── pyproject.toml ├── .flake8 ├── .github ├── pull_request_template.md ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── workflows │ └── stale.yml ├── Dockerfile ├── scripts ├── push-tutorials.sh ├── docker_launch.py ├── get_tox_envs.py └── build.py ├── tox.ini ├── .web.yml ├── .gitignore ├── .travis.yml ├── CONTRIBUTING.md ├── README.md └── LICENSE /spam/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /drybell/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /crowdsourcing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /getting_started/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /visual_relation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /spouse/.notebooks: -------------------------------------------------------------------------------- 1 | spouse_demo 2 | -------------------------------------------------------------------------------- /recsys/.notebooks: -------------------------------------------------------------------------------- 1 | recsys_tutorial 2 | -------------------------------------------------------------------------------- /getting_started/.notebooks: -------------------------------------------------------------------------------- 1 | getting_started -------------------------------------------------------------------------------- /multitask/.notebooks: -------------------------------------------------------------------------------- 1 | multitask_tutorial 2 | -------------------------------------------------------------------------------- /.gitattributes: 
-------------------------------------------------------------------------------- 1 | *.ipynb linguist-generated 2 | -------------------------------------------------------------------------------- /spam/.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | results/ 3 | logs/ -------------------------------------------------------------------------------- /crowdsourcing/.notebooks: -------------------------------------------------------------------------------- 1 | crowdsourcing_tutorial 2 | -------------------------------------------------------------------------------- /visual_relation/.notebooks: -------------------------------------------------------------------------------- 1 | visual_relation_tutorial -------------------------------------------------------------------------------- /drybell/.scripts: -------------------------------------------------------------------------------- 1 | drybell_dask.py 2 | drybell_spark.py 3 | -------------------------------------------------------------------------------- /getting_started/.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | results/ 3 | logs/ -------------------------------------------------------------------------------- /spouse/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.1.1 2 | tensorflow==1.14.0 3 | -------------------------------------------------------------------------------- /multitask/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.1.1 2 | tensorboardX>=1.6,<2.0 3 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .tox 3 | .vscode 4 | build 5 | *cache* 6 | */data 7 | *env* 8 | -------------------------------------------------------------------------------- /recsys/requirements.txt: -------------------------------------------------------------------------------- 1 | gdown==3.8.1 2 | numpy==1.16.4 3 | tensorflow==1.14.0 4 | textblob==0.15.3 -------------------------------------------------------------------------------- /spouse/.gitignore: -------------------------------------------------------------------------------- 1 | # Stored dev, train, test dataframes and labels, and dbpedia entries. 
2 | *.pkl 3 | -------------------------------------------------------------------------------- /getting_started/requirements.txt: -------------------------------------------------------------------------------- 1 | pandas>=0.24.0,<0.25.0 2 | scikit-learn>=0.20.2 3 | textblob==0.15.3 4 | -------------------------------------------------------------------------------- /spam/.notebooks: -------------------------------------------------------------------------------- 1 | 01_spam_tutorial 2 | 02_spam_data_augmentation_tutorial 3 | 03_spam_data_slicing_tutorial 4 | -------------------------------------------------------------------------------- /spouse/imgs/sentence.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel-tutorials/HEAD/spouse/imgs/sentence.jpg -------------------------------------------------------------------------------- /drybell/requirements.txt: -------------------------------------------------------------------------------- 1 | dask[complete]>=2.1.0,<3.0.0 2 | pyarrow>=0.14.1,<0.15.0 3 | pyspark==2.4.3 4 | spacy>=2.1.0,<3.0.0 -------------------------------------------------------------------------------- /drybell/data/raw_data.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel-tutorials/HEAD/drybell/data/raw_data.parquet -------------------------------------------------------------------------------- /crowdsourcing/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.1.1 2 | pytorch-transformers==1.0.0 3 | tensorflow==1.14.0 4 | textblob==0.15.3 5 | -------------------------------------------------------------------------------- /getting_started/img/Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel-tutorials/HEAD/getting_started/img/Overview.png -------------------------------------------------------------------------------- /drybell/data/celebrity_knowledge_base.txt: -------------------------------------------------------------------------------- 1 | jones ken 2 | jonesington kenneth 3 | jennifer smith 4 | todd smith 5 | pedro sanchez 6 | milly chan 7 | -------------------------------------------------------------------------------- /visual_relation/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.0 2 | pandas>=0.24.0,<0.25.0 3 | torchvision==0.6.0 4 | Pillow==6.2.0 5 | tensorboardX==1.8 6 | svn 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | #### ESSENTIAL LIBRARIES 2 | 3 | snorkel==0.9.9 4 | 5 | 6 | #### DEV TOOLS 7 | 8 | black>=19.3b0 9 | flake8>=3.7.8 10 | jupyter>=1.0.0 11 | jupytext>=1.2.0 12 | nbconvert>=6.5.3 13 | -------------------------------------------------------------------------------- /spam/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.5.3 2 | names==0.3.0 3 | nltk==3.4.5 4 | numpy>=1.16.5 5 | pandas>=1.0.0 6 | scikit-learn>=0.20.2 7 | spacy>=2.1.6,<2.2.0 8 | tensorflow>=1.14.0 9 | textblob==0.15.3 10 | -------------------------------------------------------------------------------- /spouse/README.md: 
-------------------------------------------------------------------------------- 1 | # Spouse Tutorial 2 | In this tutorial, we show how Snorkel can be used for Information Extraction. We walk through an example text classification task for information extraction, where we use labeling functions involving keywords and distant supervision. 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | target-version = ['py36', 'py37', 'py38'] 4 | exclude = ''' 5 | /( 6 | \.eggs 7 | | \.git 8 | | \.mypy_cache 9 | | \.tox 10 | | \.env.* 11 | | \.venv 12 | | _build 13 | | build 14 | | dist 15 | )/ 16 | ''' 17 | -------------------------------------------------------------------------------- /getting_started/README.md: -------------------------------------------------------------------------------- 1 | # Getting Started Tutorial 2 | In this quick walkthrough, we preview the high level workflow and interfaces of Snorkel using a canonical machine learning problem: classifying spam. We provide a quick overview of how to use Snorkel's three data operators: labeling functions, transformation functions, and slicing functions. 3 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | E203, 4 | E265, # For Jupytext percent format 5 | E266, 6 | E402, 7 | E501, 8 | E731, 9 | E741, 10 | W503, 11 | W605, 12 | F403 13 | exclude = 14 | .eggs, 15 | .git, 16 | .mypy_cache, 17 | .tox, 18 | .env**, 19 | .venv**, 20 | _build, 21 | build, 22 | dist 23 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description of proposed changes 2 | 3 | ## Related issue(s) 4 | 5 | Fixes # (issue) 6 | 7 | ## Test plan 8 | 9 | ## Checklist 10 | 11 | Need help on these? Just ask! 12 | 13 | * [ ] I have read the **CONTRIBUTING** document. 14 | * [ ] I have verified that my changes are covered by continuous integration. 15 | * [ ] All new and existing tests passed. 16 | -------------------------------------------------------------------------------- /visual_relation/README.md: -------------------------------------------------------------------------------- 1 | # Visual Relation Tutorial 2 | In this tutorial, we show how Snorkel can be applied to an image classification task. 3 | We consider the task of identifying visual relationships such as `riding` and `carrying` between objects in an image. 4 | We use categorical and spatial intuition about how objects in a visual relationship interact with each other to create labeling functions, and then use a label model to obtain high quality training labels for training a deep neural network to perform the task. 5 | -------------------------------------------------------------------------------- /spam/README.md: -------------------------------------------------------------------------------- 1 | # Spam Tutorials 2 | We consider a canonical machine learning problem: classifying spam. This directory contains three tutorials, described below: 3 | * `01_spam_tutorial`: This tutorial dives deep into how we can create, analyze, and use labeling functions for the spam classification task. 
4 | * `02_spam_data_augmentation_tutorial`: This tutorial demonstrates how to write, combine and apply transformation functions for performing data augmentation. 5 | * `03_spam_data_slicing_tutorial`: This tutorial shows how we can use slicing functions to identify important slices of the data, for monitoring and improved performance. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Let us know about something new you want 4 | --- 5 | 6 | ## Is your feature request related to a problem? Please describe. 7 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 8 | 9 | ## Describe the solution you'd like 10 | A clear and concise description of what you want to happen. 11 | 12 | ## Describe alternatives you've considered 13 | A clear and concise description of any alternative solutions or features you've considered. 14 | 15 | ## Additional context 16 | Add any other context or screenshots about the feature request here. 17 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6-slim-stretch 2 | RUN apt-get update \ 3 | && apt-get install -y \ 4 | apt-utils \ 5 | build-essential \ 6 | git \ 7 | wget \ 8 | curl \ 9 | unzip 10 | 11 | WORKDIR /app 12 | COPY . /app 13 | ARG TUTORIAL 14 | RUN wget -O snorkel-requirements.txt \ 15 | https://raw.githubusercontent.com/snorkel-team/snorkel/master/requirements.txt \ 16 | && pip3 install -r $TUTORIAL/requirements.txt \ 17 | && pip3 install -r requirements.txt \ 18 | && pip3 install -r snorkel-requirements.txt \ 19 | && python3 -m spacy download en_core_web_sm 20 | 21 | WORKDIR $TUTORIAL 22 | ENTRYPOINT ["jupyter", "notebook", "--ip=0.0.0.0", "--no-browser", "--allow-root"] 23 | -------------------------------------------------------------------------------- /recsys/README.md: -------------------------------------------------------------------------------- 1 | # Recommender Systems Tutorial 2 | In this tutorial, we provide a simple walkthrough of how to use Snorkel to build a recommender system. 3 | We consider a setting similar to the [Netflix challenge](https://www.kaggle.com/netflix-inc/netflix-prize-data), but with books instead of movies. 4 | We use book metadata and review text to create labeling functions that estimate user ratings for books they have read or plan to read. 5 | We then use Snorkel's `LabelModel` to combine the outputs of those labeling functions, and train a model to predict whether a user will read and like any given book (and therefore what books should be recommended to the user) based only on what books the user has interacted with in the past. 6 | -------------------------------------------------------------------------------- /crowdsourcing/README.md: -------------------------------------------------------------------------------- 1 | # Crowdsourcing Tutorial 2 | In this tutorial, we provide a simple walkthrough of how to use Snorkel in conjunction with crowdsourcing to create a training set for a sentiment analysis task. 3 | We use a dataset where we have crowdsourced labels for about half of the training data points. 
4 | The crowdsourced labels are fairly accurate, but do not cover the entire training dataset, nor are they available for the test set or during inference. 5 | To make up for their lack of training set coverage, we combine crowdsourced labels with heuristic labeling functions to increase the number of training labels we have. 6 | Then we combine the training labels using a `LabelModel` and use its outputs to train a deep learning model which can be applied to new, unseen data to automatically make predictions. 7 | -------------------------------------------------------------------------------- /multitask/README.md: -------------------------------------------------------------------------------- 1 | # Multitask Learning Tutorial 2 | Multi-task learning, or training a single model on multiple tasks, is becoming a standard tool for the modern ML practitioner (see Ruder's [survey](http://ruder.io/multi-task/) from 2017 for a nice overview). 3 | It often leads to computational gains (one model performing many tasks takes up less memory and storage) as well as performance gains (learning to do well on a related _auxiliary_ task can improve the model's ability on the _primary_ task). 4 | While the primary purpose of the Snorkel project is to support training data creation and management, it also comes with a PyTorch-based modeling framework intended to support flexible multi-task learning (e.g. slice-aware models). 5 | In this tutorial, we introduce the basic interfaces and flow of multi-task learning tools within Snorkel. 6 | 7 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: Mark/close stale issues and pull requests 2 | 3 | on: 4 | schedule: 5 | - cron: "0 12 * * *" 6 | 7 | jobs: 8 | stale: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/stale@v1 14 | with: 15 | repo-token: ${{ secrets.GITHUB_TOKEN }} 16 | stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.' 17 | stale-pr-message: 'This pull request is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.' 18 | stale-issue-label: 'no-issue-activity' 19 | stale-pr-label: 'no-pr-activity' 20 | exempt-issue-label: 'no-stale' 21 | days-before-stale: 30 22 | days-before-close: 5 23 | -------------------------------------------------------------------------------- /crowdsourcing/download-data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | # Check that we are running from the right directory. 5 | if [ ! "${PWD##*/}" = "crowdsourcing" ]; then 6 | echo "Script must be run from crowdsourcing directory" >&2 7 | exit 1 8 | fi 9 | 10 | FILES=( "weather-non-agg-DFE.csv" "weather-evaluated-agg-DFE.csv" ) 11 | RELOAD=false 12 | 13 | # Check if at least any file is missing. If so, reload all data. 14 | for filename in "${FILES[@]}" 15 | do 16 | if [ ! 
-e "data/$filename" ]; then 17 | RELOAD=true 18 | fi 19 | done 20 | 21 | if [ "$RELOAD" = "true" ]; then 22 | if [ -d "data/" ]; then rm -Rf "data/"; fi 23 | mkdir -p data 24 | wget https://www.dropbox.com/s/94d2wsrrwh1ioyd/weather-non-agg-DFE.csv -P data 25 | wget https://www.dropbox.com/s/upz3ijyp7rztse6/weather-evaluated-agg-DFE.csv -P data 26 | fi 27 | -------------------------------------------------------------------------------- /spouse/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | # Check that we are running from the right directory. 5 | if [ ! "${PWD##*/}" = "spouse" ]; then 6 | echo "Script must be run from spouse directory" >&2 7 | exit 1 8 | fi 9 | 10 | DATA_URL="https://www.dropbox.com/s/jmrvyaqew4zp9cy/spouse_data.zip" 11 | FILES=( "train_data.pkl" "dev_data.pkl" "test_data.pkl" "dbpedia.pkl" ) 12 | RELOAD=false 13 | 14 | # Check if at least any file is missing. If so, reload all data. 15 | for filename in "${FILES[@]}" 16 | do 17 | if [ ! -e "data/$filename" ]; then 18 | RELOAD=true 19 | fi 20 | done 21 | 22 | if [ "$RELOAD" = true ]; then 23 | if [ -d "data/" ]; then rm -Rf "data/"; fi 24 | mkdir -p data 25 | wget $DATA_URL -O data.zip --no-check-certificate 26 | cd data/ 27 | unzip ../data.zip 28 | rm ../data.zip 29 | rm trained_spouse_model 30 | fi 31 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Let us know about a bug you found 4 | --- 5 | 6 | ## Issue description 7 | 8 | A clear and concise description of what the bug is. 9 | 10 | ## Code example/repro steps 11 | 12 | Please try to provide a minimal example to repro the bug. 13 | Error messages and stack traces are also helpful. 14 | 15 | ## Expected behavior 16 | A clear and concise description of what you expected to happen. 17 | 18 | ## Screenshots 19 | If applicable, add screenshots to help explain your problem. 20 | No screenshots of code! 21 | 22 | ## System info 23 | 24 | * How you installed Snorkel (conda, pip, source): 25 | * Build command you used (if compiling from source): 26 | * OS: 27 | * Python version: 28 | * Snorkel version: 29 | * Versions of any other relevant libraries: 30 | 31 | ## Additional context 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /spam/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | # Check that we are running from the right directory. 5 | if [ ! "${PWD##*/}" = "spam" ]; then 6 | echo "Script must be run from spam directory" >&2 7 | exit 1 8 | fi 9 | 10 | FILES=( "Youtube01-Psy.csv" "Youtube02-KatyPerry.csv" "Youtube03-LMFAO.csv" "Youtube04-Eminem.csv" "Youtube05-Shakira.csv" ) 11 | DATA_URL="http://archive.ics.uci.edu/ml/machine-learning-databases/00380/YouTube-Spam-Collection-v1.zip" 12 | RELOAD=false 13 | 14 | # Check if at least any file is missing. If so, reload all data. 15 | for filename in "${FILES[@]}" 16 | do 17 | if [ ! 
-e "data/$filename" ]; then 18 | RELOAD=true 19 | fi 20 | done 21 | 22 | if [ "$RELOAD" = true ]; then 23 | if [ -d "data/" ]; then rm -Rf "data/"; fi 24 | mkdir -p data 25 | curl $DATA_URL > data.zip 26 | mv data.zip data/ 27 | cd data 28 | unzip data.zip 29 | rm data.zip 30 | rm -rf __MACOSX 31 | cd .. 32 | fi 33 | -------------------------------------------------------------------------------- /getting_started/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | # Check that we are running from the right directory. 5 | if [ ! "${PWD##*/}" = "getting_started" ]; then 6 | echo "Script must be run from getting_started directory" >&2 7 | exit 1 8 | fi 9 | 10 | FILES=( "Youtube01-Psy.csv" "Youtube02-KatyPerry.csv" "Youtube03-LMFAO.csv" "Youtube04-Eminem.csv" "Youtube05-Shakira.csv" ) 11 | DATA_URL="http://archive.ics.uci.edu/ml/machine-learning-databases/00380/YouTube-Spam-Collection-v1.zip" 12 | RELOAD=false 13 | 14 | # Check if at least any file is missing. If so, reload all data. 15 | for filename in "${FILES[@]}" 16 | do 17 | if [ ! -e "data/$filename" ]; then 18 | RELOAD=true 19 | fi 20 | done 21 | 22 | if [ "$RELOAD" = true ]; then 23 | if [ -d "data/" ]; then rm -Rf "data/"; fi 24 | mkdir -p data 25 | curl $DATA_URL > data.zip 26 | mv data.zip data/ 27 | cd data 28 | unzip data.zip 29 | rm data.zip 30 | rm -rf __MACOSX 31 | cd .. 32 | fi 33 | -------------------------------------------------------------------------------- /scripts/push-tutorials.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | setup_git() { 4 | git config --global user.email "54054737+snork2d2@users.noreply.github.com" 5 | git config --global user.name "snork2d2" 6 | } 7 | 8 | checkout_website() { 9 | git clone https://$GITHUB_TOKEN@github.com/snorkel-team/website.git 10 | } 11 | 12 | build_tutorials() { 13 | # Clear artifacts 14 | rm -rf build website/_use_cases website/_getting_started 15 | # Generate markdown files 16 | tox -e markdown 17 | # Special handling of getting_started.md 18 | mkdir -p website/_getting_started 19 | mv -f build/getting_started.md website/_getting_started 20 | # Move the rest of the tutorials 21 | mkdir -p website/_use_cases 22 | mv -f build/*.md website/_use_cases 23 | } 24 | 25 | push_tutorials() { 26 | cd website 27 | git add . 
-u 28 | git commit -m "[DEPLOY $TRAVIS_BUILD_NUMBER] Update tutorials" 29 | git push https://$GITHUB_TOKEN@github.com/snorkel-team/website.git master -f 30 | } 31 | 32 | echo "Setting up git" 33 | setup_git 34 | echo "Checking out website" 35 | checkout_website 36 | echo "Building tutorial web pages" 37 | build_tutorials 38 | echo "Pushing tutorial web pages" 39 | push_tutorials 40 | -------------------------------------------------------------------------------- /getting_started/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import subprocess 4 | 5 | import pandas as pd 6 | 7 | 8 | def load_unlabeled_spam_dataset(): 9 | """Load spam training dataset without any labels.""" 10 | if os.path.basename(os.getcwd()) == "snorkel-tutorials": 11 | os.chdir("getting_started") 12 | try: 13 | subprocess.run(["bash", "download_data.sh"], check=True, stderr=subprocess.PIPE) 14 | except subprocess.CalledProcessError as e: 15 | print(e.stderr.decode()) 16 | raise e 17 | filenames = sorted(glob.glob("data/Youtube*.csv")) 18 | dfs = [] 19 | for i, filename in enumerate(filenames, start=1): 20 | df = pd.read_csv(filename) 21 | # Lowercase column names 22 | df.columns = map(str.lower, df.columns) 23 | # Rename fields 24 | df = df.rename(columns={"class": "label", "content": "text"}) 25 | # Remove comment_id, label fields 26 | df = df.drop("comment_id", axis=1) 27 | df = df.drop("label", axis=1) 28 | # Shuffle order 29 | df = df.sample(frac=1, random_state=123).reset_index(drop=True) 30 | dfs.append(df) 31 | return pd.concat(dfs) 32 | -------------------------------------------------------------------------------- /drybell/drybell_lfs_spark.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql import Row 2 | from snorkel.labeling.lf import labeling_function 3 | from snorkel.labeling.lf.nlp_spark import spark_nlp_labeling_function 4 | from snorkel.preprocess import preprocessor 5 | 6 | from drybell_lfs import load_celebrity_knowledge_base 7 | 8 | ABSTAIN = -1 9 | NEGATIVE = 0 10 | POSITIVE = 1 11 | 12 | 13 | @preprocessor() 14 | def combine_text(x): 15 | return Row(title=x.title, body=x.body, article=f"{x.title} {x.body}") 16 | 17 | 18 | @spark_nlp_labeling_function(text_field="article", pre=[combine_text]) 19 | def article_mentions_person(x): 20 | for ent in x.doc.ents: 21 | if ent.label_ == "PERSON": 22 | return ABSTAIN 23 | return NEGATIVE 24 | 25 | 26 | @spark_nlp_labeling_function( 27 | text_field="article", 28 | pre=[combine_text], 29 | resources=dict(celebrity_knowledge_base=load_celebrity_knowledge_base()), 30 | ) 31 | def person_in_db(x, celebrity_knowledge_base): 32 | for ent in x.doc.ents: 33 | if ent.label_ == "PERSON" and ent.text.lower() in celebrity_knowledge_base: 34 | return POSITIVE 35 | return ABSTAIN 36 | 37 | 38 | @labeling_function() 39 | def body_contains_fortune(x): 40 | return POSITIVE if "fortune" in x.body else ABSTAIN 41 | -------------------------------------------------------------------------------- /drybell/drybell_lfs.py: -------------------------------------------------------------------------------- 1 | from snorkel.labeling.lf import labeling_function 2 | from snorkel.labeling.lf.nlp import nlp_labeling_function 3 | from snorkel.preprocess import preprocessor 4 | 5 | ABSTAIN = -1 6 | NEGATIVE = 0 7 | POSITIVE = 1 8 | 9 | 10 | @preprocessor() 11 | def combine_text(x): 12 | x.article = f"{x.title} {x.body}" 13 | return x 14 | 15 | 16 | 
@nlp_labeling_function(text_field="article", pre=[combine_text]) 17 | def article_mentions_person(x): 18 | for ent in x.doc.ents: 19 | if ent.label_ == "PERSON": 20 | return ABSTAIN 21 | return NEGATIVE 22 | 23 | 24 | def load_celebrity_knowledge_base(path="drybell/data/celebrity_knowledge_base.txt"): 25 | with open(path, "r") as f: 26 | return f.read().splitlines() 27 | 28 | 29 | @nlp_labeling_function( 30 | text_field="article", 31 | pre=[combine_text], 32 | resources=dict(celebrity_knowledge_base=load_celebrity_knowledge_base()), 33 | ) 34 | def person_in_db(x, celebrity_knowledge_base): 35 | for ent in x.doc.ents: 36 | if ent.label_ == "PERSON" and ent.text.lower() in celebrity_knowledge_base: 37 | return POSITIVE 38 | return ABSTAIN 39 | 40 | 41 | @labeling_function() 42 | def body_contains_fortune(x): 43 | return POSITIVE if "fortune" in x.body else ABSTAIN 44 | -------------------------------------------------------------------------------- /drybell/drybell_dask.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import dask.dataframe as dd 4 | from snorkel.labeling.model import LabelModel 5 | from snorkel.labeling.apply.dask import DaskLFApplier 6 | 7 | from drybell_lfs import article_mentions_person, body_contains_fortune, person_in_db 8 | 9 | logging.basicConfig(level=logging.INFO) 10 | 11 | 12 | def main(data_path, output_path): 13 | # Read data 14 | logging.info(f"Reading data from {data_path}") 15 | data = dd.read_parquet(data_path) 16 | data = data.repartition(npartitions=2) 17 | 18 | # Build label matrix 19 | logging.info("Applying LFs") 20 | lfs = [article_mentions_person, body_contains_fortune, person_in_db] 21 | applier = DaskLFApplier(lfs) 22 | L = applier.apply(data) 23 | 24 | # Train label model 25 | logging.info("Training label model") 26 | label_model = LabelModel(cardinality=2) 27 | label_model.fit(L) 28 | 29 | # Generate training labels 30 | logging.info("Generating probabilistic labels") 31 | y_prob = label_model.predict_proba(L)[:, 1] 32 | data = data.reset_index().set_index("index") 33 | data_labeled = data.assign(y_prob=dd.from_array(y_prob)) 34 | dd.to_parquet(data_labeled, output_path) 35 | logging.info(f"Labels saved to {output_path}") 36 | 37 | 38 | if __name__ == "__main__": 39 | main("drybell/data/raw_data.parquet", "drybell/data/labeled_data.parquet") 40 | -------------------------------------------------------------------------------- /visual_relation/download_full_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | # Check that we are running from the right directory. 5 | if [ ! "${PWD##*/}" = "visual_relation" ]; then 6 | echo "Script must be run from visual_relation directory" >&2 7 | exit 1 8 | fi 9 | 10 | DIRS=("glove" "VRD/sg_dataset") 11 | 12 | RELOAD=false 13 | # Check if at least any file is missing. If so, reload all data. 14 | for directory_name in "${DIRS[@]}" 15 | do 16 | if [ ! -d "data/$directory_name" ]; then 17 | RELOAD=true 18 | fi 19 | done 20 | 21 | ANNOTATIONS_URL="https://www.dropbox.com/s/bnfhm6kt9xumik8/vrd.zip" 22 | IMAGES_URL="http://imagenet.stanford.edu/internal/jcjohns/visual_relations/sg_dataset.zip" 23 | GLOVE_URL="http://nlp.stanford.edu/data/wordvecs/glove.6B.zip" 24 | 25 | if [ "$RELOAD" = true ]; then 26 | if [ -d "data" ]; then rm -Rf "data"; fi 27 | mkdir -p data 28 | cd data 29 | 30 | # download and unzip metadata and annotations 31 | echo "Downloading full VRD dataset..." 
32 | wget $ANNOTATIONS_URL 33 | unzip vrd.zip 34 | 35 | # Delete the zip files. 36 | rm vrd.zip 37 | cd VRD 38 | 39 | # Download and unzip all images 40 | wget $IMAGES_URL 41 | unzip sg_dataset.zip 42 | rm sg_dataset.zip 43 | cd .. 44 | 45 | mkdir -p glove 46 | cd glove 47 | 48 | wget $GLOVE_URL 49 | unzip glove.6B.zip 50 | 51 | # Delete the zip files 52 | rm glove.6B.zip 53 | cd ../.. 54 | fi 55 | 56 | -------------------------------------------------------------------------------- /drybell/drybell_spark.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pyspark.sql.functions as F 4 | from pyspark import SparkContext 5 | from pyspark.sql import SQLContext 6 | from snorkel.labeling.model import LabelModel 7 | from snorkel.labeling.apply.spark import SparkLFApplier 8 | 9 | from drybell_lfs_spark import ( 10 | article_mentions_person, 11 | body_contains_fortune, 12 | person_in_db, 13 | ) 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | 17 | 18 | def main(data_path, output_path): 19 | # Read data 20 | logging.info(f"Reading data from {data_path}") 21 | sc = SparkContext() 22 | sql = SQLContext(sc) 23 | data = sql.read.parquet(data_path) 24 | 25 | # Build label matrix 26 | logging.info("Applying LFs") 27 | lfs = [article_mentions_person, body_contains_fortune, person_in_db] 28 | applier = SparkLFApplier(lfs) 29 | L = applier.apply(data.rdd) 30 | 31 | # Train label model 32 | logging.info("Training label model") 33 | label_model = LabelModel(cardinality=2) 34 | label_model.fit(L) 35 | 36 | # Generate training labels 37 | logging.info("Generating probabilistic labels") 38 | y_prob = label_model.predict_proba(L)[:, 1] 39 | y_prob_sql_array = F.array([F.lit(y) for y in y_prob]) 40 | data_labeled = data.withColumn("y_prob", y_prob_sql_array) 41 | data_labeled.write.mode("overwrite").parquet(output_path) 42 | logging.info(f"Labels saved to {output_path}") 43 | 44 | 45 | if __name__ == "__main__": 46 | main("drybell/data/raw_data.parquet", "drybell/data/labeled_data_spark.parquet") 47 | -------------------------------------------------------------------------------- /spouse/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import subprocess 4 | from typing import Tuple 5 | 6 | import numpy as np 7 | 8 | import pandas as pd 9 | 10 | IS_TEST = os.environ.get("TRAVIS") == "true" or os.environ.get("IS_TEST") == "true" 11 | 12 | 13 | def load_data() -> Tuple[ 14 | Tuple[pd.DataFrame, np.ndarray], pd.DataFrame, Tuple[pd.DataFrame, np.ndarray] 15 | ]: 16 | """ 17 | Returns: 18 | df_dev, Y_dev: Development set data points and 1D labels ndarray. 19 | df_train: Training set data points dataframe. 20 | df_test, Y_test: Test set data points dataframe and 1D labels ndarray. 21 | """ 22 | try: 23 | subprocess.run(["bash", "download_data.sh"], check=True, stderr=subprocess.PIPE) 24 | except subprocess.CalledProcessError as e: 25 | print(e.stderr.decode()) 26 | raise e 27 | with open(os.path.join("data", "dev_data.pkl"), "rb") as f: 28 | df_dev = pickle.load(f) 29 | Y_dev = pickle.load(f) 30 | 31 | with open(os.path.join("data", "train_data.pkl"), "rb") as f: 32 | df_train = pickle.load(f) 33 | if IS_TEST: 34 | # Reduce train set size to speed up travis. 
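# (IS_TEST is true when the TRAVIS or IS_TEST environment variable is set; see the top of this file.)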
35 | df_train = df_train.iloc[:2000] 36 | 37 | with open(os.path.join("data", "test_data.pkl"), "rb") as f: 38 | df_test = pickle.load(f) 39 | Y_test = pickle.load(f) 40 | 41 | # Convert labels to {0, 1} format from {-1, 1} format. 42 | Y_dev = (1 + Y_dev) // 2 43 | Y_test = (1 + Y_test) // 2 44 | return ((df_dev, Y_dev), df_train, (df_test, Y_test)) 45 | 46 | 47 | def get_n_epochs() -> int: 48 | return 3 if IS_TEST else 30 49 | -------------------------------------------------------------------------------- /multitask/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | import numpy as np 4 | from sklearn.model_selection import train_test_split 5 | 6 | 7 | DataSplits = Tuple[ 8 | Tuple[np.ndarray, np.ndarray], 9 | Tuple[np.ndarray, np.ndarray], 10 | Tuple[np.ndarray, np.ndarray], 11 | ] 12 | 13 | 14 | def make_circle_dataset(n: int, r: float, **kwargs: Any) -> DataSplits: 15 | X = np.random.uniform(0, 1, size=(n, 2)) * 2 - 1 16 | Y = (X[:, 0] ** 2 + X[:, 1] ** 2 < r).astype(int) 17 | return split_data(X, Y, **kwargs) 18 | 19 | 20 | def make_inv_circle_dataset(n: int, r: float, **kwargs: Any) -> DataSplits: 21 | X = np.random.uniform(0, 1, size=(n, 2)) * 2 - 1 22 | Y = (X[:, 0] ** 2 + X[:, 1] ** 2 > r).astype(int) 23 | return split_data(X, Y, **kwargs) 24 | 25 | 26 | def make_square_dataset(n: int, r: float, **kwargs: Any) -> DataSplits: 27 | X = np.random.uniform(0, 1, size=(n, 2)) * 2 - 1 28 | Y = ((abs(X[:, 0]) < r / 2) * (abs(X[:, 1]) < r / 2)).astype(int) 29 | return split_data(X, Y, **kwargs) 30 | 31 | 32 | def split_data( 33 | X: np.ndarray, 34 | Y: np.ndarray, 35 | splits: Tuple[float] = (0.8, 0.1, 0.1), 36 | seed: int = 123, 37 | ) -> DataSplits: 38 | """Split data twice using sklearn train_test_split helper.""" 39 | assert len(splits) == 3 40 | 41 | X_train, X_test, Y_train, Y_test = train_test_split( 42 | X, Y, test_size=splits[2], random_state=seed 43 | ) 44 | X_train, X_valid, Y_train, Y_valid = train_test_split( 45 | X_train, Y_train, test_size=splits[1] / sum(splits[:2]), random_state=seed 46 | ) 47 | 48 | return (X_train, Y_train), (X_valid, Y_valid), (X_test, Y_test) 49 | -------------------------------------------------------------------------------- /visual_relation/download_sample_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | # Check that we are running from the right directory. 5 | if [ ! "${PWD##*/}" = "visual_relation" ]; then 6 | echo "Script must be run from visual_relation directory" >&2 7 | exit 1 8 | fi 9 | 10 | DIRS=("glove" "VRD/sg_dataset/samples") 11 | 12 | RELOAD=false 13 | # Check if at least any file is missing. If so, reload all data. 14 | for directory_name in "${DIRS[@]}" 15 | do 16 | if [ ! 
-d "data/$directory_name" ]; then 17 | RELOAD=true 18 | fi 19 | done 20 | 21 | ANNOTATIONS_URL="https://www.dropbox.com/s/bnfhm6kt9xumik8/vrd.zip" 22 | SAMPLE_IMAGES_URL="https://github.com/Prof-Lu-Cewu/Visual-Relationship-Detection.git" 23 | 24 | # NOTE: We download a smaller version of the 6B glove embeddings file, where 25 | # originally GLOVE_URL="http://nlp.stanford.edu/data/wordvecs/glove.6B.zip" 26 | GLOVE_URL="https://www.dropbox.com/s/2yg2r8931qx12xp/glove.100d.zip" 27 | 28 | if [ "$RELOAD" = true ]; then 29 | if [ -d "data" ]; then rm -Rf "data"; fi 30 | mkdir -p data 31 | cd data 32 | 33 | # download and unzip metadata and annotations 34 | wget $ANNOTATIONS_URL 35 | unzip vrd.zip 36 | 37 | # Delete the zip files. 38 | rm vrd.zip 39 | cd VRD 40 | 41 | # Download and unzip sample images 42 | echo "Downloading sample VRD dataset..." 43 | mkdir sg_dataset 44 | cd sg_dataset 45 | git clone $SAMPLE_IMAGES_URL 46 | mv Visual-Relationship-Detection/samples ./ 47 | rm -r Visual-Relationship-Detection 48 | cd ../.. 49 | 50 | mkdir -p glove 51 | cd glove 52 | 53 | wget $GLOVE_URL 54 | unzip glove.100d.zip 55 | 56 | # Delete the zip files 57 | rm glove.100d.zip 58 | cd ../.. 59 | fi 60 | 61 | -------------------------------------------------------------------------------- /drybell/README.md: -------------------------------------------------------------------------------- 1 | # Snorkel Drybell Example 2 | 3 | This example is based on the 4 | [Snorkel Drybell project](https://ai.googleblog.com/2019/03/harnessing-organizational-knowledge-for.html), 5 | a collaboration between the Snorkel team and Google to implement weak supervision at industrial scale. 6 | You can read more in the 7 | [blog post](https://ai.googleblog.com/2019/03/harnessing-organizational-knowledge-for.html) 8 | and [research paper (SIGMOD Industry, 2019)](https://arxiv.org/abs/1812.00417). 9 | The paper used a running example of classifying documents as containing a celebrity mention or not, 10 | which is what we use here as well. 11 | The data is a very small set of six faux newspaper articles and titles, stored as a 12 | [Parquet file](https://parquet.apache.org/): 13 | 14 | ``` 15 | Title Body 16 | ----- ---- 17 | Sports team wins the game! It was an exciting game. The team won at the end. 18 | Jennifer Smith donates entire fortune. She has a lot of money. Now she has less, because... 19 | ... 20 | ``` 21 | 22 | Of course, with such a small (and very fake) dataset, we don't expect to produce 23 | high quality models. 24 | The goal here is to demonstrate how Snorkel can be used in a large-scale production setting. 25 | We present two scripts — 26 | one using Snorkel's [Dask](https://dask.org/) interface 27 | and one using Snorkel's [Spark](https://spark.apache.org/) interface 28 | — that represent how Snorkel can be deployed as part of a pipeline. 29 | We also demonstrate Snorkel's `NLPLabelingFunction` interface, similar to the 30 | `NLPLabelingFunction` template presented in the Drybell paper. 
31 | 32 | If you plan to execute these scripts, do so from the `snorkel-tutorials` directory: 33 | 34 | ```bash 35 | python3 drybell/drybell_dask.py 36 | 37 | # or 38 | 39 | python3 drybell/drybell_spark.py 40 | ``` 41 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | skipsdist = true 3 | envlist = 4 | getting_started, 5 | spouse, 6 | visual_relation, 7 | spam, 8 | crowdsourcing, 9 | multitask, 10 | drybell, 11 | recsys, 12 | style, 13 | 14 | [testenv] 15 | description = test/sync for {envname} 16 | deps = 17 | -rrequirements.txt 18 | getting_started: -rgetting_started/requirements.txt 19 | spouse: -rspouse/requirements.txt 20 | spam: -rspam/requirements.txt 21 | multitask: -rmultitask/requirements.txt 22 | visual_relation: -rvisual_relation/requirements.txt 23 | crowdsourcing: -rcrowdsourcing/requirements.txt 24 | recsys: -rrecsys/requirements.txt 25 | drybell: -rdrybell/requirements.txt 26 | passenv = 27 | JAVA_HOME 28 | IS_TEST 29 | TRAVIS 30 | commands_pre = 31 | drybell: python -m spacy download en_core_web_sm 32 | # Available posargs: test, sync, html 33 | commands = 34 | getting_started: python {toxinidir}/scripts/build.py {posargs:test} getting_started 35 | spouse: python {toxinidir}/scripts/build.py {posargs:test} spouse 36 | spam: python {toxinidir}/scripts/build.py {posargs:test} spam 37 | multitask: python {toxinidir}/scripts/build.py {posargs:test} multitask 38 | visual_relation: python {toxinidir}/scripts/build.py {posargs:test} visual_relation 39 | crowdsourcing: python {toxinidir}/scripts/build.py {posargs:test} crowdsourcing 40 | recsys: python {toxinidir}/scripts/build.py {posargs:test} recsys 41 | drybell: python {toxinidir}/scripts/build.py {posargs:test} drybell 42 | 43 | [testenv:style] 44 | description = check the code style 45 | commands = 46 | black --check {toxinidir} 47 | flake8 {toxinidir} 48 | 49 | [testenv:fix] 50 | description = run code stylers 51 | commands = black {toxinidir} 52 | 53 | [testenv:markdown] 54 | description = generate markdown webpages 55 | commands = python {toxinidir}/scripts/build.py markdown {posargs} 56 | -------------------------------------------------------------------------------- /.web.yml: -------------------------------------------------------------------------------- 1 | # Tutorial configuration fields 2 | # notebook: either the path to the notebook, or optionally a directory name 3 | # (if there's only one notebook in the directory) 4 | # title (optional): title to display on website 5 | # description (optional): description to display on website 6 | # exclude_all_output (optional): exclude output from all cells? 7 | # 8 | # If either a title or description is not given, the notebook won't 9 | # be rendered on the Use Cases page. The order on the Use Cases page 10 | # is given by the order here. 
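# Example (hypothetical entry; the notebook path and text below are illustrative only):
#   - notebook: my_tutorial/my_tutorial.ipynb
#     title: My Tutorial
#     description: One-line summary shown on the Use Cases page
#     exclude_all_output: true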
11 | 12 | tutorials: 13 | - notebook: spam/01_spam_tutorial.ipynb 14 | title: Intro to Labeling Functions 15 | description: Labeling data for spam classification 16 | - notebook: spam/02_spam_data_augmentation_tutorial.ipynb 17 | title: Intro to Transformation Functions 18 | description: Data augmentation for spam classification 19 | - notebook: spam/03_spam_data_slicing_tutorial.ipynb 20 | title: Intro to Slicing Functions 21 | description: Monitoring critical data subsets for spam classification 22 | - notebook: crowdsourcing 23 | title: Hybrid Crowd Labeling Workflows in Snorkel 24 | description: Mixing programmatic and crowdworker labels for sentiment analysis 25 | - notebook: recsys 26 | title: Building Recommender Systems in Snorkel 27 | description: Labeling text reviews for book recommendations 28 | - notebook: visual_relation 29 | title: Visual Relation Detection in Snorkel 30 | description: Multi-class labeling for visual relationships in images 31 | - notebook: spouse 32 | title: Information Extraction in Snorkel 33 | description: Labeling spouse mentions in documents 34 | - notebook: multitask 35 | title: Intro to Snorkel's Multitask Learning System 36 | description: State-of-the-art framework for pretraining & parameter sharing 37 | - notebook: getting_started 38 | exclude_all_output: true 39 | -------------------------------------------------------------------------------- /scripts/docker_launch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import errno 3 | import os 4 | import socket 5 | import subprocess 6 | 7 | 8 | def check_docker() -> None: 9 | try: 10 | subprocess.run(["docker", "--version"], check=True) 11 | except (FileNotFoundError, subprocess.CalledProcessError): 12 | raise ValueError("Error calling Docker. Is it installed?") 13 | 14 | 15 | def build_image(tutorial_name: str) -> None: 16 | arg = f"TUTORIAL={tutorial_name}" 17 | tag = f"--tag={tutorial_name}" 18 | subprocess.run(["docker", "build", "--build-arg", arg, tag, "."], check=True) 19 | 20 | 21 | def check_port(port: int) -> None: 22 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 23 | try: 24 | s.bind(("127.0.0.1", port)) 25 | except socket.error as e: 26 | if e.errno == errno.EADDRINUSE: 27 | raise ValueError(f"Port {port} is already in use") 28 | else: 29 | raise e 30 | 31 | 32 | def run_image(tutorial_name: str, port: int) -> None: 33 | tag = f"{tutorial_name}:latest" 34 | p_cfg = f"{port}:{port}" 35 | p_arg = f"--port={port}" 36 | check_run = subprocess.run(["docker", "run", "-it", "-p", p_cfg, tag, p_arg]) 37 | if check_run.returncode: 38 | raise ValueError( 39 | "Error running container. If you haven't built it yet, " 40 | "try running this script with the --build flag." 41 | ) 42 | 43 | 44 | def docker_launch(tutorial_name: str, build: bool, port: int) -> None: 45 | if not os.path.isdir(tutorial_name): 46 | raise ValueError(f"{tutorial_name} is not a valid tutorial") 47 | check_docker() 48 | if build: 49 | build_image(tutorial_name) 50 | check_port(port) 51 | run_image(tutorial_name, port) 52 | 53 | 54 | if __name__ == "__main__": 55 | desc = "Build and run Docker images for Snorkel Tutorials." 56 | parser = argparse.ArgumentParser(description=desc) 57 | parser.add_argument("tutorial_name", help="Name of the tutorial (directory)") 58 | parser.add_argument( 59 | "--build", action="store_true", default=False, help="Build the Docker image?" 
60 | ) 61 | parser.add_argument("--port", type=int, default=8888, help="Jupyter port for host") 62 | args = parser.parse_args() 63 | docker_launch(**vars(args)) 64 | -------------------------------------------------------------------------------- /spouse/preprocessors.py: -------------------------------------------------------------------------------- 1 | # + 2 | from typing import Optional 3 | 4 | from snorkel.preprocess import preprocessor 5 | from snorkel.types import DataPoint 6 | 7 | 8 | # - 9 | 10 | 11 | @preprocessor() 12 | def get_person_text(cand: DataPoint) -> DataPoint: 13 | """ 14 | Returns the text for the two person mentions in candidate 15 | """ 16 | person_names = [] 17 | for index in [1, 2]: 18 | field_name = "person{}_word_idx".format(index) 19 | start = cand[field_name][0] 20 | end = cand[field_name][1] + 1 21 | person_names.append(" ".join(cand["tokens"][start:end])) 22 | cand.person_names = person_names 23 | return cand 24 | 25 | 26 | @preprocessor() 27 | def get_person_last_names(cand: DataPoint) -> DataPoint: 28 | """ 29 | Returns the last names for the two person mentions in candidate 30 | """ 31 | cand = get_person_text(cand) 32 | person1_name, person2_name = cand.person_names 33 | person1_lastname = ( 34 | person1_name.split(" ")[-1] if len(person1_name.split(" ")) > 1 else None 35 | ) 36 | person2_lastname = ( 37 | person2_name.split(" ")[-1] if len(person2_name.split(" ")) > 1 else None 38 | ) 39 | cand.person_lastnames = [person1_lastname, person2_lastname] 40 | return cand 41 | 42 | 43 | @preprocessor() 44 | def get_text_between(cand: DataPoint) -> DataPoint: 45 | """ 46 | Returns the text between the two person mentions in the sentence 47 | """ 48 | start = cand.person1_word_idx[1] + 1 49 | end = cand.person2_word_idx[0] 50 | cand.text_between = " ".join(cand.tokens[start:end]) 51 | return cand 52 | 53 | 54 | @preprocessor() 55 | def get_left_tokens(cand: DataPoint) -> DataPoint: 56 | """ 57 | Returns tokens in the length 3 window to the left of the person mentions 58 | """ 59 | # TODO: need to pass window as input params 60 | window = 3 61 | 62 | end = cand.person1_word_idx[0] 63 | cand.person1_left_tokens = cand.tokens[0:end][-1 - window : -1] 64 | 65 | end = cand.person2_word_idx[0] 66 | cand.person2_left_tokens = cand.tokens[0:end][-1 - window : -1] 67 | return cand 68 | 69 | 70 | # Helper function to get last name for dbpedia entries. 
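# Uses the same split-on-whitespace logic as get_person_last_names above, returning None for single-token names.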
71 | def last_name(s: str) -> Optional[str]: 72 | name_parts = s.split(" ") 73 | return name_parts[-1] if len(name_parts) > 1 else None 74 | -------------------------------------------------------------------------------- /crowdsourcing/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from typing import Tuple 4 | 5 | import pandas as pd 6 | 7 | 8 | LABEL_MAPPING = { 9 | "Negative": 0, 10 | "Positive": 1, 11 | "I can't tell": 2, 12 | "Neutral / author is just sharing information": 2, 13 | "Tweet not related to weather condition": 2, 14 | } 15 | 16 | 17 | def load_data() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: 18 | if os.path.basename(os.getcwd()) != "crowdsourcing": 19 | raise ValueError("Function must be called from crowdsourcing/ directory.") 20 | try: 21 | subprocess.run(["bash", "download-data.sh"], check=True, stderr=subprocess.PIPE) 22 | except subprocess.CalledProcessError as e: 23 | print(e.stderr.decode()) 24 | raise e 25 | 26 | gold_labels = pd.read_csv("data/weather-evaluated-agg-DFE.csv") 27 | gold_labels = gold_labels.set_index("tweet_id", drop=False) 28 | labeled = gold_labels[ 29 | (gold_labels["is_the_category_correct_for_this_tweet:confidence"] == 1) 30 | & ( 31 | (gold_labels.sentiment == "Positive") 32 | | (gold_labels.sentiment == "Negative") 33 | ) 34 | ] 35 | labeled = labeled.sample(frac=1, random_state=123) # Shuffle data points. 36 | 37 | crowd_labels = pd.read_csv("data/weather-non-agg-DFE.csv") 38 | # Keep only the tweets with available ground truth. 39 | crowd_labels = crowd_labels.join( 40 | labeled, on=["tweet_id"], lsuffix=".raw", rsuffix=".gold", how="inner" 41 | ) 42 | crowd_labels = crowd_labels[["tweet_id", "worker_id", "emotion"]] 43 | crowd_labels.emotion = crowd_labels.emotion.map(LABEL_MAPPING) 44 | crowd_labels = crowd_labels.rename(columns=dict(emotion="label")) 45 | crowd_labels = crowd_labels.set_index("tweet_id") 46 | crowd_labels = crowd_labels[crowd_labels["label"] != 2] 47 | 48 | df_dev = labeled[:50] 49 | df_dev = df_dev[["tweet_id", "tweet_text", "sentiment"]] 50 | df_dev.sentiment = df_dev.sentiment.map(LABEL_MAPPING).values 51 | # Remove half the labels 52 | crowd_labels = crowd_labels.drop(df_dev[: int(len(df_dev) / 2)].tweet_id) 53 | 54 | df_test = labeled[50:100] 55 | df_test = df_test[["tweet_id", "tweet_text", "sentiment"]] 56 | df_test.sentiment = df_test.sentiment.map(LABEL_MAPPING).values 57 | crowd_labels = crowd_labels.drop(df_test.tweet_id) 58 | 59 | df_train = labeled[100:][["tweet_id", "tweet_text"]] 60 | # Remove half the labels 61 | crowd_labels = crowd_labels.drop(df_train[: int(len(df_train) / 2)].tweet_id) 62 | 63 | return crowd_labels, df_train, df_dev, df_test 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other 
infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | db.sqlite3-journal 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 89 | # install all needed dependencies. 90 | #Pipfile.lock 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env* 100 | .venv* 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | 125 | # Editors 126 | .vscode/ 127 | *.swp 128 | 129 | # Dask 130 | dask-worker-space/ 131 | 132 | # MacOS 133 | .DS_Store 134 | 135 | # Generic data directories 136 | **/data/* 137 | 138 | # Nohup 139 | nohup.out 140 | 141 | # Modeling 142 | checkpoint/ 143 | checkpoints/ 144 | 145 | -------------------------------------------------------------------------------- /spouse/tf_model.py: -------------------------------------------------------------------------------- 1 | # + 2 | from typing import Tuple 3 | import numpy as np 4 | import pandas as pd 5 | 6 | import tensorflow as tf 7 | from tensorflow.keras.layers import ( 8 | Bidirectional, 9 | Concatenate, 10 | Dense, 11 | Embedding, 12 | Input, 13 | LSTM, 14 | ) 15 | 16 | 17 | def get_feature_arrays(df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 18 | """Get np arrays of upto max_length tokens and person idxs.""" 19 | bet = df.between_tokens 20 | left = df.apply(lambda c: c.tokens[: c.person1_word_idx[0]][-4:-1], axis=1) 21 | right = df.person2_right_tokens 22 | 23 | def pad_or_truncate(l, max_length=40): 24 | return l[:max_length] + [""] * (max_length - len(l)) 25 | 26 | left_tokens = np.array(list(map(pad_or_truncate, left))) 27 | bet_tokens = np.array(list(map(pad_or_truncate, bet))) 28 | right_tokens = np.array(list(map(pad_or_truncate, right))) 29 | return left_tokens, bet_tokens, right_tokens 30 | 31 | 32 | def bilstm( 33 | tokens: tf.Tensor, 34 | rnn_state_size: int = 64, 35 | num_buckets: int = 40000, 36 | embed_dim: int = 36, 37 | ): 38 | ids = tf.strings.to_hash_bucket(tokens, num_buckets) 39 | embedded_input = Embedding(num_buckets, embed_dim)(ids) 40 | return Bidirectional(LSTM(rnn_state_size, activation=tf.nn.relu))( 41 | embedded_input, mask=tf.strings.length(tokens) 42 | ) 43 | 44 | 45 | def get_model( 46 | 
rnn_state_size: int = 64, num_buckets: int = 40000, embed_dim: int = 12 47 | ) -> tf.keras.Model: 48 | """ 49 | Return LSTM model for predicting label probabilities. 50 | Args: 51 | rnn_state_size: LSTM state size. 52 | num_buckets: Number of buckets to hash strings to integers. 53 | embed_dim: Size of token embeddings. 54 | Returns: 55 | model: A compiled LSTM model. 56 | """ 57 | left_ph = Input((None,), dtype="string") 58 | bet_ph = Input((None,), dtype="string") 59 | right_ph = Input((None,), dtype="string") 60 | left_embs = bilstm(left_ph, rnn_state_size, num_buckets, embed_dim) 61 | bet_embs = bilstm(bet_ph, rnn_state_size, num_buckets, embed_dim) 62 | right_embs = bilstm(right_ph, rnn_state_size, num_buckets, embed_dim) 63 | layer = Concatenate(1)([left_embs, bet_embs, right_embs]) 64 | layer = Dense(64, activation=tf.nn.relu)(layer) 65 | layer = Dense(32, activation=tf.nn.relu)(layer) 66 | probabilities = Dense(2, activation=tf.nn.softmax)(layer) 67 | model = tf.keras.Model(inputs=[bet_ph, left_ph, right_ph], outputs=probabilities) 68 | model.compile(tf.train.AdagradOptimizer(0.1), "categorical_crossentropy") 69 | return model 70 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # Only run CI on master, release branches, tags, and PRs 2 | if: tag IS present OR type = pull_request OR ((branch = master OR branch =~ release-*) AND type = push) 3 | 4 | # Main dist is Python 5 | language: python 6 | 7 | # Cache package wheels 8 | cache: pip 9 | 10 | # python3.7 only available on xenial 11 | dist: xenial 12 | 13 | # Set JAVA_HOME 14 | env: 15 | global: 16 | - secure: "3HpKlZ72Bmljvtjrl5oSQgwBFSJ8ouEtueUwrz1bKhG2tLHnmtqmYgnfIup9YZtm++1YtcG7Q5sTR8KA1JWwnJL/OGLrAD0ddIbjv+MuGeUnsIvi5y/X5jD+rpvaPjcvbMTcJG+yPhYwlmTKYT3+gSkFYZMj83VfFKRgaazTIIyT13ZPoRENcpSScve/IAusSdlRlBt3Vp2rLlbYMu9hInKkc8Qde6vlpKyfEKPXO37EuTzszxaYr9kt7I1OHKvONI2nNry0X0+WwtpsfF3d+OVLE2z6CtbeolMvbqnFf162c34Z0OzjDpREfIfSorLhXHwuhu8josTH+Ftg+0CvEnmpFbnG3scUgzJGdEaAwB2CSd1qIxt/+gHRyL1meTAfP+eiaPtEcbciFDj+5iuV9udkNXIarKbGA/eBq8mlEXNQMeOZcHPwNpdZtI9/n8YRQGfLsQsfjOX7g22XVRi5DVrmfn9R0hkoQmGk3FXGJdPS+2bVMzmXQdMvmaZOvn79RiaQVPMQSXIE2tRlYhlDE93DzkL2Sb300JG7g6YxzBQtUUFdd1w2yipi6Mup5UouATNk/TaxVXpW6DaWVw58a552NqpIoksdPpYmmo/hCWajivXxw68Tj5OQHZ/puSQyQW17Ct9kMetVtWY68gGN3jT7i4DKaDheGDMqan0s3QM=" 17 | - JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 18 | 19 | jobs: 20 | include: 21 | - name: "Python 3.6: tutorial tests" 22 | python: 3.6 23 | - name: "Python 3.7: tutorial tests" 24 | python: 3.7 25 | 26 | jobs: 27 | include: 28 | - stage: "Python 3.6: fail fast tests" 29 | python: 3.6 30 | install: 31 | - pip install -U pip setuptools 32 | - pip install -U tox==3.13.0 33 | - tox --notest -e style 34 | script: tox -e style 35 | - stage: all tests 36 | if: type = pull_request 37 | name: "Python 3.6: affected tutorial tests" 38 | python: 3.6 39 | env: JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 40 | - if: type = pull_request 41 | name: "Python 3.7: affected tutorial tests" 42 | python: 3.7 43 | env: JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 44 | - if: type != pull_request 45 | name: "Python 3.6: all tutorial tests" 46 | python: 3.6 47 | env: ALL_ENVS=--all JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 48 | - if: type != pull_request 49 | name: "Python 3.7: all tutorial tests + web build" 50 | python: 3.7 51 | env: BUILD_WEB=true ALL_ENVS=--all JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 52 | 53 | # Install JDK8 for 
PySpark tests 54 | before_install: 55 | - sudo add-apt-repository -y ppa:openjdk-r/ppa 56 | - sudo apt-get -qq update 57 | - sudo apt-get install -y openjdk-8-jdk --no-install-recommends 58 | - sudo update-java-alternatives -s java-1.8.0-openjdk-amd64 59 | 60 | install: 61 | - pip install -U pip setuptools 62 | - pip install -U tox==3.13.0 63 | - python scripts/get_tox_envs.py --plan $ALL_ENVS 64 | - tox --notest -e $(python scripts/get_tox_envs.py $ALL_ENVS) 65 | 66 | script: 67 | - tox -e $(python scripts/get_tox_envs.py $ALL_ENVS) 68 | 69 | after_success: 70 | - if [[ ${BUILD_WEB} == "true" ]] && [[ ${TRAVIS_BRANCH} == "master" ]]; then 71 | echo "Deploying tutorials"; 72 | bash ./scripts/push-tutorials.sh; 73 | fi 74 | -------------------------------------------------------------------------------- /scripts/get_tox_envs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | from typing import List 5 | 6 | EXTRA_ENVIRONMENTS = ["style"] 7 | SKIP_EXT = [".md", ".txt"] 8 | 9 | 10 | def get_modified_paths(no_travis_strict: bool) -> List[str]: 11 | # Call git diff --name-only HEAD $(git merge-base HEAD $TRAVIS_BRANCH) 12 | # to get paths affected by patch 13 | base_branch = os.environ.get("TRAVIS_BRANCH") 14 | if base_branch is None: 15 | if not no_travis_strict: 16 | raise ValueError("No environment variable $TRAVIS_BRANCH") 17 | base_branch = "master" 18 | merge_base = subprocess.run( 19 | ["git", "merge-base", "HEAD", base_branch], stdout=subprocess.PIPE 20 | ) 21 | cp = subprocess.run( 22 | ["git", "diff", "--name-only", "HEAD", merge_base.stdout.rstrip()], 23 | stdout=subprocess.PIPE, 24 | ) 25 | return [str(s, "utf-8") for s in cp.stdout.splitlines()] 26 | 27 | 28 | def get_default_environments() -> List[str]: 29 | # Call tox -l to get default environments 30 | cp = subprocess.run(["tox", "-l"], stdout=subprocess.PIPE) 31 | return [str(s, "utf-8") for s in cp.stdout.splitlines()] 32 | 33 | 34 | def get_changed_tox_envs(all_envs: bool, no_travis_strict: bool, plan: bool) -> None: 35 | # Check we're in the right place, otherwise git paths are messed up 36 | if os.path.split(os.getcwd())[1] != "snorkel-tutorials": 37 | raise ValueError("Execute this script from the snorkel-tutorials directory") 38 | # If we passed in --all flag, just run all environments 39 | default_environments = get_default_environments() 40 | if all_envs: 41 | if plan: 42 | print("Running all environments") 43 | print(",".join(default_environments)) 44 | return 45 | # Find paths modified in patch 46 | modified_paths = get_modified_paths(no_travis_strict) 47 | if plan: 48 | print(f"Modified paths: {','.join(modified_paths)}") 49 | # Find unique snorkel-tutorial subdirectories affected by patch 50 | unique_directories = set() 51 | for p in modified_paths: 52 | # Skip changed markdown / text files as they don't need a test env. 53 | if any(p.endswith(ext) for ext in SKIP_EXT): 54 | continue 55 | splits = p.split("/") 56 | # If there's a directory, parse it; otherwise, add placeholder "." 57 | unique_directories.add("." if len(splits) == 1 else splits[0]) 58 | unique_defaults = [d for d in unique_directories if d in default_environments] 59 | # If all changed directories are among the defaults, then only run them 60 | # plus EXTRA_ENVIRONMENTS. 
61 | if len(unique_defaults) == len(unique_directories): 62 | run_environments = unique_defaults + EXTRA_ENVIRONMENTS 63 | if plan: 64 | print( 65 | f"Changed tutorial directories: {unique_defaults}, " 66 | f"running environments: {run_environments}" 67 | ) 68 | print(",".join(run_environments)) 69 | # Otherwise, run all environments 70 | else: 71 | if plan: 72 | print( 73 | "Change in non-tutorial directory. " 74 | f"All changes: [{unique_directories}]. " 75 | "Running on all environments." 76 | ) 77 | print(",".join(default_environments)) 78 | 79 | 80 | if __name__ == "__main__": 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument( 83 | "--no-travis-strict", 84 | action="store_true", 85 | default=False, 86 | help="Don't fail if not in Travis?", 87 | ) 88 | parser.add_argument( 89 | "--plan", 90 | action="store_true", 91 | default=False, 92 | help="Print out plan for Travis execution?", 93 | ) 94 | parser.add_argument( 95 | "--all", 96 | dest="all_envs", 97 | action="store_true", 98 | default=False, 99 | help="Run all environments?", 100 | ) 101 | args = parser.parse_args() 102 | get_changed_tox_envs(**vars(args)) 103 | -------------------------------------------------------------------------------- /visual_relation/utils.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import json 3 | import os 4 | import subprocess 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | # %% 10 | 11 | 12 | def flatten_vrd_relationship(img, relationship, objects, predicates): 13 | """Create a per-relationship entry from a per-image entry JSON.""" 14 | new_relationship_dict = {} 15 | new_relationship_dict["subject_category"] = objects[ 16 | relationship["subject"]["category"] 17 | ] 18 | new_relationship_dict["object_category"] = objects[ 19 | relationship["object"]["category"] 20 | ] 21 | new_relationship_dict["subject_bbox"] = relationship["subject"]["bbox"] 22 | new_relationship_dict["object_bbox"] = relationship["object"]["bbox"] 23 | 24 | if predicates[relationship["predicate"]] == "ride": 25 | new_relationship_dict["label"] = 0 26 | elif predicates[relationship["predicate"]] == "carry": 27 | new_relationship_dict["label"] = 1 28 | else: 29 | new_relationship_dict["label"] = 2 30 | 31 | new_relationship_dict["source_img"] = img 32 | 33 | return new_relationship_dict 34 | 35 | 36 | # %% 37 | def vrd_to_pandas( 38 | relationships_set, objects, predicates, list_of_predicates, keys_list=None 39 | ): 40 | """Create Pandas DataFrame from JSON of relationships.""" 41 | relationships = [] 42 | 43 | for img in relationships_set: 44 | if (keys_list is None) or (img in keys_list): 45 | img_relationships = relationships_set[img] 46 | for relationship in img_relationships: 47 | predicate_idx = relationship["predicate"] 48 | if predicates[predicate_idx] in list_of_predicates: 49 | relationships.append( 50 | flatten_vrd_relationship(img, relationship, objects, predicates) 51 | ) 52 | else: 53 | continue 54 | return pd.DataFrame.from_dict(relationships) 55 | 56 | 57 | # %% 58 | def load_vrd_data(sample=False, is_travis=False): 59 | """Download and load Pandas DataFrame of VRD relationships. 60 | 61 | NOTE: Only loads semantic relationships as data points. 
62 | """ 63 | 64 | if sample or is_travis: 65 | try: 66 | subprocess.run( 67 | ["bash", "download_sample_data.sh"], check=True, stderr=subprocess.PIPE 68 | ) 69 | except subprocess.CalledProcessError as e: 70 | print(e.stderr.decode()) 71 | raise e 72 | else: 73 | try: 74 | subprocess.run( 75 | ["bash", "download_full_data.sh"], check=True, stderr=subprocess.PIPE 76 | ) 77 | except subprocess.CalledProcessError as e: 78 | print(e.stderr.decode()) 79 | raise e 80 | 81 | relationships_train = json.load(open("data/VRD/annotations_train.json")) 82 | relationships_test = json.load(open("data/VRD/annotations_test.json")) 83 | 84 | objects = json.load(open("data/VRD/objects.json")) 85 | predicates = json.load(open("data/VRD/predicates.json")) 86 | semantic_predicates = [ 87 | "carry", 88 | "cover", 89 | "fly", 90 | "look", 91 | "lying on", 92 | "park on", 93 | "sit on", 94 | "stand on", 95 | "ride", 96 | ] 97 | 98 | np.random.seed(123) 99 | val_idx = list(np.random.choice(len(relationships_train), 1000, replace=False)) 100 | relationships_val = { 101 | key: value 102 | for i, (key, value) in enumerate(relationships_train.items()) 103 | if i in val_idx 104 | } 105 | relationships_train = { 106 | key: value 107 | for i, (key, value) in enumerate(relationships_train.items()) 108 | if i not in val_idx 109 | } 110 | 111 | # TODO: hack to work with small sample of data for tox 112 | if os.path.isdir("data/VRD/sg_dataset/samples"): 113 | # pass in list of images as keys_list 114 | keys_list = os.listdir("data/VRD/sg_dataset/samples") 115 | test_df = vrd_to_pandas( 116 | relationships_test, 117 | objects, 118 | predicates, 119 | list_of_predicates=semantic_predicates, 120 | keys_list=keys_list, 121 | ) 122 | return test_df, test_df, test_df 123 | elif os.path.isdir("data/VRD/sg_dataset/sg_train_images"): 124 | train_df = vrd_to_pandas( 125 | relationships_train, 126 | objects, 127 | predicates, 128 | list_of_predicates=semantic_predicates, 129 | ) 130 | train_df["labels"] = -1 * np.ones(len(train_df)) 131 | valid_df = vrd_to_pandas( 132 | relationships_val, 133 | objects, 134 | predicates, 135 | list_of_predicates=semantic_predicates, 136 | ) 137 | test_df = vrd_to_pandas( 138 | relationships_test, 139 | objects, 140 | predicates, 141 | list_of_predicates=semantic_predicates, 142 | ) 143 | return train_df, valid_df, test_df 144 | -------------------------------------------------------------------------------- /spam/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import subprocess 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | import torch.nn as nn 10 | import tensorflow as tf 11 | from sklearn.model_selection import train_test_split 12 | 13 | from snorkel.classification.data import DictDataset, DictDataLoader 14 | 15 | 16 | def load_spam_dataset(load_train_labels: bool = False, split_dev_valid: bool = False): 17 | if os.path.basename(os.getcwd()) == "snorkel-tutorials": 18 | os.chdir("spam") 19 | try: 20 | subprocess.run(["bash", "download_data.sh"], check=True, stderr=subprocess.PIPE) 21 | except subprocess.CalledProcessError as e: 22 | print(e.stderr.decode()) 23 | raise e 24 | filenames = sorted(glob.glob("data/Youtube*.csv")) 25 | 26 | dfs = [] 27 | for i, filename in enumerate(filenames, start=1): 28 | df = pd.read_csv(filename) 29 | # Lowercase column names 30 | df.columns = map(str.lower, df.columns) 31 | # Remove comment_id field 32 | df = 
df.drop("comment_id", axis=1) 33 | # Add field indicating source video 34 | df["video"] = [i] * len(df) 35 | # Rename fields 36 | df = df.rename(columns={"class": "label", "content": "text"}) 37 | # Shuffle order 38 | df = df.sample(frac=1, random_state=123).reset_index(drop=True) 39 | dfs.append(df) 40 | 41 | df_train = pd.concat(dfs[:4]) 42 | df_dev = df_train.sample(100, random_state=123) 43 | 44 | if not load_train_labels: 45 | df_train["label"] = np.ones(len(df_train["label"])) * -1 46 | df_valid_test = dfs[4] 47 | df_valid, df_test = train_test_split( 48 | df_valid_test, test_size=250, random_state=123, stratify=df_valid_test.label 49 | ) 50 | 51 | if split_dev_valid: 52 | return df_train, df_dev, df_valid, df_test 53 | else: 54 | return df_train, df_test 55 | 56 | 57 | def get_keras_logreg(input_dim, output_dim=2): 58 | model = tf.keras.Sequential() 59 | if output_dim == 1: 60 | loss = "binary_crossentropy" 61 | activation = tf.nn.sigmoid 62 | else: 63 | loss = "categorical_crossentropy" 64 | activation = tf.nn.softmax 65 | dense = tf.keras.layers.Dense( 66 | units=output_dim, 67 | input_dim=input_dim, 68 | activation=activation, 69 | kernel_regularizer=tf.keras.regularizers.l2(0.001), 70 | ) 71 | model.add(dense) 72 | opt = tf.keras.optimizers.Adam(lr=0.01) 73 | model.compile(optimizer=opt, loss=loss, metrics=["accuracy"]) 74 | return model 75 | 76 | 77 | def get_keras_lstm(num_buckets, embed_dim=16, rnn_state_size=64): 78 | lstm_model = tf.keras.Sequential() 79 | lstm_model.add(tf.keras.layers.Embedding(num_buckets, embed_dim)) 80 | lstm_model.add(tf.keras.layers.LSTM(rnn_state_size, activation=tf.nn.relu)) 81 | lstm_model.add(tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)) 82 | lstm_model.compile("Adagrad", "binary_crossentropy", metrics=["accuracy"]) 83 | return lstm_model 84 | 85 | 86 | def get_keras_early_stopping(patience=10, monitor="val_acc"): 87 | """Stops training if monitor value doesn't exceed the current max value after patience num of epochs""" 88 | return tf.keras.callbacks.EarlyStopping( 89 | monitor=monitor, patience=patience, verbose=1, restore_best_weights=True 90 | ) 91 | 92 | 93 | def map_pad_or_truncate(string, max_length=30, num_buckets=30000): 94 | """Tokenize text, pad or truncate to get max_length, and hash tokens.""" 95 | ids = tf.keras.preprocessing.text.hashing_trick( 96 | string, n=num_buckets, hash_function="md5" 97 | ) 98 | return ids[:max_length] + [0] * (max_length - len(ids)) 99 | 100 | 101 | def featurize_df_tokens(df): 102 | return np.array(list(map(map_pad_or_truncate, df.text))) 103 | 104 | 105 | def preview_tfs(df, tfs): 106 | transformed_examples = [] 107 | for f in tfs: 108 | for i, row in df.sample(frac=1, random_state=2).iterrows(): 109 | transformed_or_none = f(row) 110 | # If TF returned a transformed example, record it in dict and move to next TF. 
111 | if transformed_or_none is not None: 112 | transformed_examples.append( 113 | OrderedDict( 114 | { 115 | "TF Name": f.name, 116 | "Original Text": row.text, 117 | "Transformed Text": transformed_or_none.text, 118 | } 119 | ) 120 | ) 121 | break 122 | return pd.DataFrame(transformed_examples) 123 | 124 | 125 | def df_to_features(vectorizer, df, split): 126 | """Convert pandas DataFrame containing spam data to bag-of-words PyTorch features.""" 127 | words = [row.text for i, row in df.iterrows()] 128 | 129 | if split == "train": 130 | feats = vectorizer.fit_transform(words) 131 | else: 132 | feats = vectorizer.transform(words) 133 | X = feats.todense() 134 | Y = df["label"].values 135 | return X, Y 136 | 137 | 138 | def create_dict_dataloader(X, Y, split, **kwargs): 139 | """Create a DictDataLoader for bag-of-words features.""" 140 | ds = DictDataset.from_tensors(torch.FloatTensor(X), torch.LongTensor(Y), split) 141 | return DictDataLoader(ds, **kwargs) 142 | 143 | 144 | def get_pytorch_mlp(hidden_dim, num_layers): 145 | layers = [] 146 | for _ in range(num_layers): 147 | layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]) 148 | return nn.Sequential(*layers) 149 | -------------------------------------------------------------------------------- /visual_relation/model.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from pathlib import Path 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import pandas 7 | import torch 8 | import torch.nn as nn 9 | from PIL import Image 10 | from torchvision import transforms 11 | 12 | from snorkel.analysis import Scorer 13 | from snorkel.classification import DictDataset, MultitaskClassifier, Operation, Task 14 | from snorkel.classification.data import XDict, YDict 15 | 16 | 17 | def union(bbox1, bbox2): 18 | """Create the union of the two bboxes. 19 | 20 | Parameters 21 | ---------- 22 | bbox1 23 | Coordinates of first bounding box 24 | bbox2 25 | Coordinates of second bounding box 26 | 27 | Returns 28 | ------- 29 | [y0, y1, x0, x1] 30 | Coordinates of union of input bounding boxes 31 | 32 | """ 33 | y0 = min(bbox1[0], bbox2[0]) 34 | y1 = max(bbox1[1], bbox2[1]) 35 | x0 = min(bbox1[2], bbox2[2]) 36 | x1 = max(bbox1[3], bbox2[3]) 37 | return [y0, y1, x0, x1] 38 | 39 | 40 | def crop_img_arr(img_arr, bbox): 41 | """Crop bounding box from image. 
42 | 43 | Parameters 44 | ---------- 45 | img_arr 46 | Image in array format 47 | bbox 48 | Coordinates of bounding box to crop 49 | 50 | Returns 51 | ------- 52 | img_arr 53 | Cropped image 54 | 55 | """ 56 | return img_arr[bbox[0] : bbox[1], bbox[2] : bbox[3], :] 57 | 58 | 59 | class SceneGraphDataset(DictDataset): 60 | """Dataloader for Scene Graph Dataset.""" 61 | 62 | def __init__( 63 | self, 64 | name: str, 65 | split: str, 66 | image_dir: str, 67 | df: pandas.DataFrame, 68 | image_size=224, 69 | ) -> None: 70 | self.image_dir = Path(image_dir) 71 | X_dict = { 72 | "img_fn": df["source_img"].tolist(), 73 | "obj_bbox": df["object_bbox"].tolist(), 74 | "sub_bbox": df["subject_bbox"].tolist(), 75 | "obj_category": df["object_category"].tolist(), 76 | "sub_category": df["subject_category"].tolist(), 77 | } 78 | Y_dict = { 79 | "visual_relation_task": torch.LongTensor(df["label"].to_numpy()) 80 | } # change to take in the rounded train labels 81 | super(SceneGraphDataset, self).__init__(name, split, X_dict, Y_dict) 82 | 83 | # define standard set of transformations to apply to each image 84 | self.transform = transforms.Compose( 85 | [ 86 | transforms.Resize((image_size, image_size)), 87 | transforms.ToTensor(), 88 | transforms.Normalize( 89 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 90 | ), 91 | ] 92 | ) 93 | 94 | def __getitem__(self, index: int) -> Tuple[XDict, YDict]: 95 | img_fn = self.X_dict["img_fn"][index] 96 | img_arr = np.array(Image.open(self.image_dir / img_fn)) 97 | 98 | obj_bbox = self.X_dict["obj_bbox"][index] 99 | sub_bbox = self.X_dict["sub_bbox"][index] 100 | obj_category = self.X_dict["obj_category"][index] 101 | sub_category = self.X_dict["sub_category"][index] 102 | 103 | # compute crops 104 | obj_crop = crop_img_arr(img_arr, obj_bbox) 105 | sub_crop = crop_img_arr(img_arr, sub_bbox) 106 | union_crop = crop_img_arr(img_arr, union(obj_bbox, sub_bbox)) 107 | 108 | # transform each crop 109 | x_dict = { 110 | "obj_crop": self.transform(Image.fromarray(obj_crop)), 111 | "sub_crop": self.transform(Image.fromarray(sub_crop)), 112 | "union_crop": self.transform(Image.fromarray(union_crop)), 113 | "obj_category": obj_category, 114 | "sub_category": sub_category, 115 | } 116 | 117 | y_dict = {name: label[index] for name, label in self.Y_dict.items()} 118 | return x_dict, y_dict 119 | 120 | def __len__(self): 121 | return len(self.X_dict["img_fn"]) 122 | 123 | 124 | class WordEmb(nn.Module): 125 | """Extract and concat word embeddings for obj and sub categories.""" 126 | 127 | def __init__(self, glove_fn="data/glove/glove.6B.100d.txt"): 128 | super(WordEmb, self).__init__() 129 | 130 | self.word_embs = pandas.read_csv( 131 | glove_fn, sep=" ", index_col=0, header=None, quoting=csv.QUOTE_NONE 132 | ) 133 | 134 | def _get_wordvec(self, word): 135 | return self.word_embs.loc[word].as_matrix() 136 | 137 | def forward(self, obj_category, sub_category): 138 | obj_emb = self._get_wordvec(obj_category) 139 | sub_emb = self._get_wordvec(sub_category) 140 | embs = np.concatenate([obj_emb, sub_emb], axis=1) 141 | return torch.FloatTensor(embs) 142 | 143 | 144 | # Classes and helper functions for defining classifier 145 | def init_fc(fc): 146 | torch.nn.init.xavier_uniform_(fc.weight) 147 | fc.bias.data.fill_(0.01) 148 | 149 | 150 | class FlatConcat(nn.Module): 151 | """Module that flattens and concatenates features""" 152 | 153 | def forward(self, *inputs): 154 | return torch.cat([input.view(input.size(0), -1) for input in inputs], dim=1) 155 | 156 | 157 | # Helper functions 
to geenerate operations 158 | def get_op_sequence(): 159 | # define feature extractors for each of the (union, subject, and object) image crops 160 | union_feat_op = Operation( 161 | name="union_feat_op", 162 | module_name="feat_extractor", 163 | inputs=[("_input_", "union_crop")], 164 | ) 165 | 166 | sub_feat_op = Operation( 167 | name="sub_feat_op", 168 | module_name="feat_extractor", 169 | inputs=[("_input_", "sub_crop")], 170 | ) 171 | 172 | obj_feat_op = Operation( 173 | name="obj_feat_op", 174 | module_name="feat_extractor", 175 | inputs=[("_input_", "obj_crop")], 176 | ) 177 | 178 | # define an operation to extract word embeddings for subject and object categories 179 | word_emb_op = Operation( 180 | name="word_emb_op", 181 | module_name="word_emb", 182 | inputs=[("_input_", "sub_category"), ("_input_", "obj_category")], 183 | ) 184 | 185 | # define an operation to concatenate image features and word embeddings 186 | concat_op = Operation( 187 | name="concat_op", 188 | module_name="feat_concat", 189 | inputs=["obj_feat_op", "sub_feat_op", "union_feat_op", "word_emb_op"], 190 | ) 191 | 192 | # define an operation to make a prediction over all concatenated features 193 | prediction_op = Operation( 194 | name="head_op", module_name="prediction_head", inputs=["concat_op"] 195 | ) 196 | 197 | return [ 198 | sub_feat_op, 199 | obj_feat_op, 200 | union_feat_op, 201 | word_emb_op, 202 | concat_op, 203 | prediction_op, 204 | ] 205 | 206 | 207 | # Create model from pre loaded resnet cnn. 208 | def create_model(resnet_cnn): 209 | # freeze the resnet weights 210 | for param in resnet_cnn.parameters(): 211 | param.requires_grad = False 212 | 213 | # define input features 214 | in_features = resnet_cnn.fc.in_features 215 | feature_extractor = nn.Sequential(*list(resnet_cnn.children())[:-1]) 216 | 217 | # initialize FC layer: maps 3 sets of image features to class logits 218 | WEMB_SIZE = 100 219 | fc = nn.Linear(in_features * 3 + 2 * WEMB_SIZE, 3) 220 | init_fc(fc) 221 | 222 | # define layers 223 | module_pool = nn.ModuleDict( 224 | { 225 | "feat_extractor": feature_extractor, 226 | "prediction_head": fc, 227 | "feat_concat": FlatConcat(), 228 | "word_emb": WordEmb(), 229 | } 230 | ) 231 | 232 | # define task flow through modules 233 | op_sequence = get_op_sequence() 234 | pred_cls_task = Task( 235 | name="visual_relation_task", 236 | module_pool=module_pool, 237 | op_sequence=op_sequence, 238 | scorer=Scorer(metrics=["f1_micro"]), 239 | ) 240 | return MultitaskClassifier([pred_cls_task]) 241 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Snorkel Tutorials 2 | 3 | We love contributors, so first and foremost, thank you! 4 | We're actively working on our contributing guidelines, so this document is subject to change. 5 | First things first: we adhere to the 6 | [Contributor Covenant Code of Conduct](http://contributor-covenant.org/version/1/4/), 7 | so please read through it before contributing. 
8 | 9 | ### Contents 10 | 11 | * [Types of Tutorials](#types-of-tutorials) 12 | * [Dev Setup](#dev-setup) 13 | * [Making Changes to an Existing Tutorial](#making-changes-to-an-existing-tutorial) 14 | * [Adding a New Tutorial](#adding-a-new-tutorial) 15 | * [Testing Changes Locally](#testing-changes-locally) 16 | * [Previewing Changes to the Website](#previewing-changes-to-the-website) 17 | 18 | 19 | ## Types of Tutorials 20 | 21 | Currently, we have notebook-based tutorials and script-based tutorials. 22 | Both types are referenced in this guide. 23 | Notebook-based tutorials act as walkthroughs of concepts 24 | (the [`spam` tutorial](https://github.com/snorkel-team/snorkel-tutorials/tree/master/spam) is a good example). 25 | Script-based tutorials act more as examples of how to deploy certain Snorkel functionality 26 | (the [`drybell` tutorial](https://github.com/snorkel-team/snorkel-tutorials/tree/master/drybell) is a good example). 27 | We default to using notebook-based tutorials. 28 | 29 | ## Dev Setup 30 | 31 | For dev setup, you will need to install [`tox`](https://tox.readthedocs.io), and set up a virtualenv with all the requirements. 32 | For example, if you use `pip`, and want to work on the `spam` tutorial: 33 | 34 | ```bash 35 | python3 -m pip install -U 'tox>=3.13.0,<4.0.0' 36 | python3 -m pip install --upgrade virtualenv 37 | virtualenv -p python3 .env 38 | source .env/bin/activate 39 | 40 | python3 -m pip install -r requirements.txt 41 | python3 -m pip install -r spam/requirements.txt # Change based on tutorial. 42 | ``` 43 | 44 | Start jupyter from the virtualenv to make sure the kernel has all the required dependencies. 45 | 46 | ## Making Changes to an Existing Tutorial 47 | 48 | First, we recommend [posting an issue](https://github.com/snorkel-team/snorkel-tutorials/issues/new) 49 | describing the improvement or fix you want to make. 50 | Once you've worked out details with the maintainers, follow these general steps: 51 | 52 | 1. Make your changes to the source files 53 | * For notebook-based tutorials, we recommend making changes to the `.py` version 54 | then syncing changes with `tox -e my_tutorial_dir -- sync`. 55 | Alternatively, if you have already run all the cells in your browser, you can select 56 | `File` → `Jupytext` → `Pair Notebook with percent Script` to save the 57 | outputs directly to the notebook version. 58 | After saving, unpair the notebook with 59 | `File` → `Jupytext` → `Unpair notebook` so jupyter does not 60 | keep updating the notebook when all cells haven't been run. 61 | * For script-based tutorials, just make the changes as you normally would. 62 | 1. [Test your changes locally](#testing-changes-locally) 63 | 1. Submit a PR! 64 | 65 | ## Adding a New Tutorial 66 | 67 | Before adding a new tutorial, we recommend posting a proposal to the 68 | [Snorkel community forum on Spectrum](https://spectrum.chat/snorkel/tutorials?tab=posts). 69 | Once you've worked out details with the maintainers, follow these general steps: 70 | 71 | 1. Create a new top-level directory (e.g. `my_tutorial_dir`) 72 | 1. Add a tutorial configuration file to your tutorial directory 73 | * For notebook-based tutorials, add a file called `.notebooks` to your tutorial directory 74 | and add the base name of each tutorial script/notebook pair (e.g. `my_tutorial`) as a 75 | separate line in `.notebooks`. 76 | See `spam` for an [example](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/.notebooks). 
77 | * For script-based tutorials, add a file called `.scripts` to your tutorial directory 78 | and add the file name of each tutorial script (e.g. `my_tutorial.py`) as a 79 | separate line in `.scripts`. 80 | See `drybell` for an [example](https://github.com/snorkel-team/snorkel-tutorials/blob/master/drybell/.scripts). 81 | 1. Add a `requirements.txt` to your directory if additional ones are needed 82 | 1. Add a command to `[testenv]` in `tox.ini` by copying `spam` and add the `requirements.txt` file if necessary. 83 | Also add the command name to `envlist`. 84 | 1. Write your tutorial! 85 | * For notebook-based tutorials, write your tutorial either as a Python script (e.g. `my_tutorial_dir/my_tutorial.py`) in [Jupytext percent format](https://gist.github.com/mwouts/91f3e1262871cdaa6d35394cd14f9bdc) or a Jupyter notebook. 86 | * Run `tox -e my_tutorial_dir -- sync` to generate a notebook version from the Python script version 87 | (or if you have run all cells, you can select 88 | `File` → `Jupytext` → `Pair Notebook with percent Script` to 89 | save the outputs directly to the notebook version, and then unpair it 90 | with `File` → `Jupytext` → `Unpair notebook` so jupyter does not 91 | keep updating the notebook when all cells haven't been run). 92 | Do this to update the notebook whenever changes are made to the tutorial script. 93 | * Run `tox -e my_tutorial_dir -- sync --py` to generate a Python script version from the notebook version. Run this command to update when changes are made to the tutorial notebook. 94 | * For script-based tutorials, write your tutorial as a Python script. 95 | 1. [Test your changes locally](#testing-changes-locally) 96 | 1. Submit a PR! Make sure to include a reference to the Spectrum planning thread. 97 | 98 | ## Testing Changes Locally 99 | 100 | ### Testing changes to tutorials 101 | 102 | You can test changes to a specific tutorial by running `tox -e my_tutorial` where `my_tutorial` is 103 | replaced by the corresponding environment name in `tox.ini`. 104 | For scripts and notebooks, this will check that they execute without erroring. 105 | For notebooks only, this will also check that any URLs in Markdown cells are reachable and that 106 | the `.ipynb` versions match the `.py` versions. 107 | Travis will also always run `tox -e style` to check code style and formatting, so you should always 108 | run this locally as well. 109 | Running `tox` on its own will test all tutorials, which can be **extremely slow** since some tutorials 110 | (like `recsys`) take a long time to run with a full dataset. 111 | 112 | ### Other `tox`-related commands 113 | 114 | * To fix code formatting issues, run `tox -e fix`. 115 | * You might need to update packages for an environment (for example, if you update `requirements.txt` 116 | dependencies). Use the `-r` flag for this. For example, you can run `tox -e spam -r -- sync` to 117 | rebuild the `spam` environment and then run the `sync` script. 118 | * Travis uses the `get_tox_envs.py` script to figure out which `tox` environments it needs to run to 119 | test a PR. Once you've committed your changes, you can preview the environments that Travis will execute 120 | by running `python3 scripts/get_tox_envs.py --plan`. 121 | 122 | ## Previewing Changes to the Website 123 | 124 | All of the tutorials listed in `.web.yml` are rendered on our [website](https://snorkel.org/use-cases/). 125 | Details on configuring with `.web.yml` are in that file.
126 | In order to display tutorials as webpages, we convert the `.ipynb` versions to Markdown. 127 | To generate all Markdown files listed in `.web.yml`, use `tox -e markdown`. 128 | This will generate files in the `build` directory. 129 | 130 | You can prevent cells from being rendered in Markdown by adding `{"tag": ["md-exclude"]}` 131 | to the cell header in the `.py` file. 132 | This is useful for confusing setup cells or cells with difficult-to-render outputs. 133 | For example: 134 | 135 | ```python 136 | # %% {"tag": ["md-exclude"]} 137 | command.do_not_show() 138 | this_line.will_not_appear() 139 | ``` 140 | 141 | You can also prevent cells from rendering output in Markdown by adding 142 | `{"tag": ["md-exclude-output"]}` to the cell header in the `.py` file. 143 | This is useful for cells that display warning messages and other confusing stuff. 144 | For example: 145 | 146 | ```python 147 | # %% {"tag": ["md-exclude-output"]} 148 | command.will_show() 149 | this_line.will_appear() 150 | print(my_object) # The output will not show 151 | ``` 152 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Snorkel Tutorials 2 | ![Snorkel Version](https://img.shields.io/badge/snorkel-0.9.5-65baf6) 3 | ![Python Version](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue) 4 | [![build](https://travis-ci.com/snorkel-team/snorkel-tutorials.svg?branch=master)](https://travis-ci.com/snorkel-team/snorkel-tutorials?branch=master) 5 | [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 6 | 7 | A collection of tutorials for [Snorkel](https://github.com/snorkel-team/snorkel). 8 | For more information, visit the [Snorkel website](https://snorkel.org). 9 | 10 | ## Contents 11 | * [Tutorials](#tutorials) 12 | * [Getting Started](#getting-started) 13 | * [Contributing to Snorkel Tutorials](#contributing-to-snorkel-tutorials) 14 | 15 | 16 | ## Tutorials 17 | We recommend that all users **start with the [Getting Started](https://snorkel.org/get-started/) tutorial hosted on the Snorkel website** for a gentle introduction to the concepts and classes of Snorkel. 18 | All other tutorials assume that you have already completed that tutorial and are familiar with its concepts. 19 | After that, the `spam` directory contains three tutorials introducing labeling functions, transformation functions, and slicing functions, respectively. 20 | The rest of the tutorials explore other tasks, techniques, and integrations. 21 | 22 | The Snorkel tutorials are grouped by application: 23 | * [`spam`](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/README.md): Is this YouTube comment spam? 24 | * [`spouse`](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spouse/README.md): Does this sentence imply that the two marked people are spouses? 25 | * [`visual_relation`](https://github.com/snorkel-team/snorkel-tutorials/blob/master/visual_relation/README.md): Is object A riding object B in the image, carrying it, or neither? 26 | * [`crowdsourcing`](https://github.com/snorkel-team/snorkel-tutorials/blob/master/crowdsourcing/README.md): Is this tweet about the weather expressing a positive, negative or neutral sentiment? 
27 | * [`multitask`](https://github.com/snorkel-team/snorkel-tutorials/blob/master/multitask/README.md) (Multi-Task Learning): A synthetic task demonstrating the native Snorkel multi-task classifier API 28 | * [`recsys`](https://github.com/snorkel-team/snorkel-tutorials/blob/master/recsys/README.md) (Recommender Systems): Will this user read and like this book? 29 | * [`drybell`](https://github.com/snorkel-team/snorkel-tutorials/blob/master/drybell/README.md): Is a celebrity mentioned in this news article? 30 | 31 | Here we provide an index pointing to different available tutorials by their task type, techniques, and integrations. 32 | * Task 33 | * Text Classification (Text): `spam`, `crowdsourcing`, `drybell` 34 | * Relation Extraction (Text): `spouse` 35 | * Visual Relationship Detection (Image): `visual_relation` 36 | * Recommender Systems: `recsys` 37 | * Techniques 38 | * Labeling with Labeling Functions (LFs): `spam`, `spouse`, `visual_relation`, `crowdsourcing` 39 | * Augmentation with Transformation Functions (TFs): `spam` 40 | * Monitoring with Slicing Functions (SFs): `spam` 41 | * Using Crowdworker Labels: `crowdsourcing` 42 | * Multi-Task Learning (MTL): `multitask`, `spam` 43 | * Integrations 44 | * TensorFlow/Keras: `spam`, `spouse` 45 | * Scikit-learn: `spam`, `crowdsourcing` 46 | * PyTorch: `multitask`, `visual_relation` 47 | * Dask: `drybell` 48 | * Spark: `drybell` 49 | 50 | 51 | ## Getting Started 52 | 53 | Step one is cloning this repo. 54 | 55 | ```bash 56 | git clone https://github.com/snorkel-team/snorkel-tutorials.git 57 | cd snorkel-tutorials 58 | ``` 59 | 60 | As with Snorkel, our tutorials require Python 3.6+. 61 | If you're looking to quickly get started with a tutorial, we recommend using 62 | our [Docker setup](#docker). 63 | If you want to install things yourself using `pip` or `conda`, you can follow 64 | our [installation steps](#install) below instead. 65 | 66 | ***Snorkel version*** 67 | 68 | This tutorials repo is pinned to a specific version of the Snorkel library, 69 | which is specified in the 70 | [requirements file](https://github.com/snorkel-team/snorkel-tutorials/blob/master/requirements.txt). 71 | Note that this will likely not be up to date with the `master` branch in 72 | the main Snorkel repo. 73 | We recommend using virtual environments or Docker containers to run the 74 | tutorials, so check out the details below. 75 | 76 | ***A quick note for Windows users*** 77 | 78 | If you're using Windows, we highly recommend using the [Docker setup](#docker) 79 | or the [Linux subsystem](https://docs.microsoft.com/en-us/windows/wsl/faq). 80 | It can be tricky to get the installation right using application-specific shells 81 | (e.g. the `conda` shell). 82 | Additionally, the shell scripts included in this repo (such as those for 83 | downloading datasets) use \*nix-style commands. 84 | 85 | 86 | ### Running with Docker 87 | 88 | We've included a Docker setup for our tutorials to make setup easy. 89 | First, make sure you have [Docker installed](https://docs.docker.com/install/) on your machine. 90 | To build and run a Docker image for a tutorial, use `scripts/docker_launch.py` with the `--build` flag. 91 | For example, run the following for the `spam` tutorial: 92 | 93 | ```bash 94 | python3 scripts/docker_launch.py spam --build 95 | ``` 96 | 97 | Building a Docker image from scratch can take anywhere between 5 and 98 | 30 minutes depending on the machine you're using. 99 | We're working on making prebuilt images available via DockerHub. 
100 | 101 | Once the image has been built, a Jupyter notebook server will be available 102 | on port 8888 (you can change the port with the `--port` command line option) 103 | and print out a link you can follow to access the browser interface. 104 | In your browser, open a `.ipynb` file you would like to run — 105 | such as `01_spam_tutorial.ipynb` — and execute the cells in sequence. 106 | 107 | Once you've built a tutorial-specific image for the first time, 108 | you can run it without the `--build` flag: 109 | 110 | ```bash 111 | python3 scripts/docker_launch.py spam 112 | ``` 113 | 114 | ### Installing yourself 115 | 116 | Running a tutorial has three required steps if you're installing yourself: 117 | 118 | 1. Installing repo-wide requirements 119 | 1. Installing tutorial-specific requirements 120 | 1. Launching a Jupyter notebook server or executing as a script 121 | 122 | We recommend installing requirements in a virtual environment using [`virtualenv`](https://virtualenv.pypa.io/en/latest/) or [`conda`](https://docs.conda.io/en/latest/). 123 | 124 | The following example commands show you how to install the requirements for the 125 | `spam` tutorial, then launch a notebook server to run the tutorial. 126 | To run a different tutorial, simply replace `spam` with the desired directory. 127 | 128 |
Installing with pip 129 |

130 | 131 | These commands assume that your Python version is 3.6+ and that the Python 3 132 | version of `pip` is available as `pip3`. 133 | It may be available as `pip` depending on how your system is configured. 134 | 135 | ```bash 136 | # [OPTIONAL] Activate a virtual environment 137 | pip3 install --upgrade virtualenv 138 | virtualenv -p python3 .envspam 139 | source .envspam/bin/activate 140 | 141 | # Install requirements (both shared and tutorial-specific) 142 | pip3 install -r requirements.txt 143 | pip3 install -r spam/requirements.txt 144 | 145 | # Launch the Jupyter notebook interface (making sure the right virtual environment is used) 146 | .envspam/bin/jupyter notebook spam 147 | ``` 148 | 149 |

150 |
151 | 152 | 153 |
Installing with conda 154 |

155 | 156 | These commands assume that your conda installation is Python 3.6+. 157 | 158 | ```bash 159 | # [OPTIONAL] Activate a virtual environment 160 | conda create --yes -n spam python=3.6 161 | conda activate spam 162 | 163 | # Install requirements (both shared and tutorial-specific) 164 | pip install environment_kernels 165 | # We specify PyTorch here to ensure compatibility, but it may not be necessary. 166 | conda install pytorch==1.1.0 -c pytorch 167 | conda install snorkel==0.9.5 -c conda-forge 168 | pip install -r spam/requirements.txt 169 | 170 | # Launch the Jupyter notebook interface 171 | jupyter notebook spam 172 | ``` 173 | 174 | Make sure to select the right kernel (`conda_spam`) when running the jupyter notebook. 175 | 176 |

177 |
178 | 179 | Then in the browser tab that opens, navigate to a `.ipynb` file you would like 180 | to run — such as `01_spam_tutorial.ipynb` — and execute the 181 | cells in sequence. 182 | 183 | Alternatively, you can run the tutorial as a script by calling `python3` on the corresponding `.py` file directly (e.g. `python3 spam/01_spam_tutorial.py`). 184 | The `.py` source files are written in [Jupytext `percent` format](https://jupytext.readthedocs.io/en/latest/), and contain the same content as the notebooks. 185 | 186 | 187 | ## Contributing to Snorkel Tutorials 188 | 189 | If you're interested in improving existing tutorials or contributing new tutorials, 190 | check out our [contributing guidelines](./CONTRIBUTING.md). 191 | -------------------------------------------------------------------------------- /visual_relation/visual_relation_tutorial.py: -------------------------------------------------------------------------------- 1 | # %% [markdown] 2 | # # Visual Relationship Detection 3 | # 4 | # In this tutorial, we focus on the task of classifying visual relationships between objects in an image. For any given image, there might be many such relationships, defined formally as a `subject object` (e.g. `person bike`). As an example, in the relationship `man riding bicycle`), "man" and "bicycle" are the subject and object, respectively, and "riding" is the relationship predicate. 5 | # 6 | # ![Visual Relationships](https://cs.stanford.edu/people/ranjaykrishna/vrd/dataset.png) 7 | # 8 | # In the examples of the relationships shown above, the red box represents the _subject_ while the green box represents the _object_. The _predicate_ (e.g. kick) denotes what relationship connects the subject and the object. 9 | # 10 | # For the purpose of this tutorial, we operate over the [Visual Relationship Detection (VRD) dataset](https://cs.stanford.edu/people/ranjaykrishna/vrd/) and focus on action relationships. We define our classification task as **identifying which of three relationships holds between the objects represented by a pair of bounding boxes.** 11 | 12 | # %% {"tags": ["md-exclude"]} 13 | import os 14 | 15 | if os.path.basename(os.getcwd()) == "snorkel-tutorials": 16 | os.chdir("visual_relation") 17 | 18 | # %% [markdown] 19 | # ### 1. Load Dataset 20 | # We load the VRD dataset and filter images with at least one action predicate in it, since these are more difficult to classify than geometric relationships like `above` or `next to`. We load the train, valid, and test sets as Pandas `DataFrame` objects with the following fields: 21 | # - `label`: The relationship between the objects. 0: `RIDE`, 1: `CARRY`, 2: `OTHER` action predicates 22 | # - `object_bbox`: coordinates of the bounding box for the object `[ymin, ymax, xmin, xmax]` 23 | # - `object_category`: category of the object 24 | # - `source_img`: filename for the corresponding image the relationship is in 25 | # - `subject_bbox`: coordinates of the bounding box for the object `[ymin, ymax, xmin, xmax]` 26 | # - `subject_category`: category of the subject 27 | 28 | # %% [markdown] 29 | # If you are running this notebook for the first time, it will take ~15 mins to download all the required sample data. 30 | # 31 | # The sampled version of the dataset **uses the same 26 data points across the train, dev, and test sets. 
32 | # This setting is meant to demonstrate quickly how Snorkel works with this task, not to demonstrate performance.** 33 | 34 | # %% 35 | from utils import load_vrd_data 36 | 37 | # setting sample=False will take ~3 hours to run (downloads full VRD dataset) 38 | sample = True 39 | is_test = os.environ.get("TRAVIS") == "true" or os.environ.get("IS_TEST") == "true" 40 | df_train, df_valid, df_test = load_vrd_data(sample, is_test) 41 | 42 | print("Train Relationships: ", len(df_train)) 43 | print("Dev Relationships: ", len(df_valid)) 44 | print("Test Relationships: ", len(df_test)) 45 | 46 | # %% [markdown] 47 | # Note that the training `DataFrame` will have a labels field with all -1s. This denotes the lack of labels for that particular dataset. In this tutorial, we will assign probabilistic labels to the training set by writing labeling functions over attributes of the subject and objects! 48 | 49 | # %% [markdown] 50 | # ## 2. Writing Labeling Functions 51 | # We now write labeling functions to detect what relationship exists between pairs of bounding boxes. To do so, we can encode various intuitions into the labeling functions: 52 | # * _Categorical_ intution: knowledge about the categories of subjects and objects usually involved in these relationships (e.g., `person` is usually the subject for predicates like `ride` and `carry`) 53 | # * _Spatial_ intuition: knowledge about the relative positions of the subject and objects (e.g., subject is usually higher than the object for the predicate `ride`) 54 | 55 | # %% 56 | RIDE = 0 57 | CARRY = 1 58 | OTHER = 2 59 | ABSTAIN = -1 60 | 61 | # %% [markdown] 62 | # We begin with labeling functions that encode categorical intuition: we use knowledge about common subject-object category pairs that are common for `RIDE` and `CARRY` and also knowledge about what subjects or objects are unlikely to be involved in the two relationships. 63 | 64 | # %% 65 | from snorkel.labeling import labeling_function 66 | 67 | # Category-based LFs 68 | @labeling_function() 69 | def lf_ride_object(x): 70 | if x.subject_category == "person": 71 | if x.object_category in [ 72 | "bike", 73 | "snowboard", 74 | "motorcycle", 75 | "horse", 76 | "bus", 77 | "truck", 78 | "elephant", 79 | ]: 80 | return RIDE 81 | return ABSTAIN 82 | 83 | 84 | @labeling_function() 85 | def lf_carry_object(x): 86 | if x.subject_category == "person": 87 | if x.object_category in ["bag", "surfboard", "skis"]: 88 | return CARRY 89 | return ABSTAIN 90 | 91 | 92 | @labeling_function() 93 | def lf_carry_subject(x): 94 | if x.object_category == "person": 95 | if x.subject_category in ["chair", "bike", "snowboard", "motorcycle", "horse"]: 96 | return CARRY 97 | return ABSTAIN 98 | 99 | 100 | @labeling_function() 101 | def lf_not_person(x): 102 | if x.subject_category != "person": 103 | return OTHER 104 | return ABSTAIN 105 | 106 | 107 | # %% [markdown] 108 | # We now encode our spatial intuition, which includes measuring the distance between the bounding boxes and comparing their relative areas. 
109 | 110 | # %% 111 | YMIN = 0 112 | YMAX = 1 113 | XMIN = 2 114 | XMAX = 3 115 | 116 | # %% 117 | import numpy as np 118 | 119 | # Distance-based LFs 120 | @labeling_function() 121 | def lf_ydist(x): 122 | if x.subject_bbox[XMAX] < x.object_bbox[XMAX]: 123 | return OTHER 124 | return ABSTAIN 125 | 126 | 127 | @labeling_function() 128 | def lf_dist(x): 129 | if np.linalg.norm(np.array(x.subject_bbox) - np.array(x.object_bbox)) <= 1000: 130 | return OTHER 131 | return ABSTAIN 132 | 133 | 134 | def area(bbox): 135 | return (bbox[YMAX] - bbox[YMIN]) * (bbox[XMAX] - bbox[XMIN]) 136 | 137 | 138 | # Size-based LF 139 | @labeling_function() 140 | def lf_area(x): 141 | if area(x.subject_bbox) / area(x.object_bbox) <= 0.5: 142 | return OTHER 143 | return ABSTAIN 144 | 145 | 146 | # %% [markdown] 147 | # Note that the labeling functions have varying empirical accuracies and coverages. Due to class imbalance in our chosen relationships, labeling functions that label the `OTHER` class have higher coverage than labeling functions for `RIDE` or `CARRY`. This reflects the distribution of classes in the dataset as well. 148 | 149 | # %% {"tags": ["md-exclude-output"]} 150 | from snorkel.labeling import PandasLFApplier 151 | 152 | lfs = [ 153 | lf_ride_object, 154 | lf_carry_object, 155 | lf_carry_subject, 156 | lf_not_person, 157 | lf_ydist, 158 | lf_dist, 159 | lf_area, 160 | ] 161 | 162 | applier = PandasLFApplier(lfs) 163 | L_train = applier.apply(df_train) 164 | L_valid = applier.apply(df_valid) 165 | 166 | # %% 167 | from snorkel.labeling import LFAnalysis 168 | 169 | Y_valid = df_valid.label.values 170 | LFAnalysis(L_valid, lfs).lf_summary(Y_valid) 171 | 172 | # %% [markdown] 173 | # ## 3. Train Label Model 174 | # We now train a multi-class `LabelModel` to assign training labels to the unalabeled training set. 175 | 176 | # %% 177 | from snorkel.labeling.model import LabelModel 178 | 179 | label_model = LabelModel(cardinality=3, verbose=True) 180 | label_model.fit(L_train, seed=123, lr=0.01, log_freq=10, n_epochs=100) 181 | 182 | # %% [markdown] 183 | # We use [F1](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html) Micro average for the multiclass setting, which calculates metrics globally across classes, by counting the total true positives, false negatives and false positives. 184 | 185 | # %% 186 | label_model.score(L_valid, Y_valid, metrics=["f1_micro"]) 187 | 188 | # %% [markdown] 189 | # ## 4. Train a Classifier 190 | # You can then use these training labels to train any standard discriminative model, such as [an off-the-shelf ResNet](https://github.com/KaimingHe/deep-residual-networks), which should learn to generalize beyond the LF's we've developed! 
191 | 192 | # %% [markdown] 193 | # #### Create DataLoaders for Classifier 194 | 195 | # %% 196 | from snorkel.classification import DictDataLoader 197 | from model import SceneGraphDataset, create_model 198 | 199 | df_train["labels"] = label_model.predict(L_train) 200 | 201 | if sample: 202 | TRAIN_DIR = "data/VRD/sg_dataset/samples" 203 | else: 204 | TRAIN_DIR = "data/VRD/sg_dataset/sg_train_images" 205 | 206 | dl_train = DictDataLoader( 207 | SceneGraphDataset("train_dataset", "train", TRAIN_DIR, df_train), 208 | batch_size=16, 209 | shuffle=True, 210 | ) 211 | 212 | dl_valid = DictDataLoader( 213 | SceneGraphDataset("valid_dataset", "valid", TRAIN_DIR, df_valid), 214 | batch_size=16, 215 | shuffle=False, 216 | ) 217 | 218 | # %% [markdown] 219 | # #### Define Model Architecture 220 | 221 | # %% 222 | import torchvision.models as models 223 | 224 | # initialize pretrained feature extractor 225 | cnn = models.resnet18(pretrained=True) 226 | model = create_model(cnn) 227 | 228 | # %% [markdown] 229 | # ### Train and Evaluate Model 230 | 231 | # %% {"tags": ["md-exclude-output"]} 232 | from snorkel.classification import Trainer 233 | 234 | trainer = Trainer( 235 | n_epochs=1, # increase for improved performance 236 | lr=1e-3, 237 | checkpointing=True, 238 | checkpointer_config={"checkpoint_dir": "checkpoint"}, 239 | ) 240 | trainer.fit(model, [dl_train]) 241 | 242 | # %% 243 | model.score([dl_valid]) 244 | 245 | # %% [markdown] 246 | # ## Recap 247 | # We have successfully trained a visual relationship detection model! Using categorical and spatial intuition about how objects in a visual relationship interact with each other, we are able to assign high quality training labels to object pairs in the VRD dataset in a multi-class classification setting. 248 | # 249 | # For more on how Snorkel can be used for visual relationship tasks, please see our [ICCV 2019 paper](https://arxiv.org/abs/1904.11622)! 
250 | -------------------------------------------------------------------------------- /recsys/utils.py: -------------------------------------------------------------------------------- 1 | import calendar 2 | import gzip 3 | import json 4 | import logging 5 | import os 6 | import pickle 7 | from datetime import datetime 8 | from typing import Any, Dict, List, Optional, Tuple 9 | 10 | import gdown 11 | import numpy as np 12 | import pandas as pd 13 | from sklearn.model_selection import train_test_split 14 | from tensorflow.keras import backend as K 15 | 16 | IS_TEST = os.environ.get("TRAVIS") == "true" or os.environ.get("IS_TEST") == "true" 17 | 18 | YA_BOOKS_URL = "https://drive.google.com/uc?id=1gH7dG4yQzZykTpbHYsrw2nFknjUm0Mol" 19 | YA_INTERACTIONS_URL = "https://drive.google.com/uc?id=1NNX7SWcKahezLFNyiW88QFPAqOAYP5qg" 20 | YA_REVIEWS_URL = "https://drive.google.com/uc?id=1M5iqCZ8a7rZRtsmY5KQ5rYnP9S0bQJVo" 21 | SMALL_DATA_URL = "https://drive.google.com/uc?id=1_UY4xTbk3o0xjGbVllQZC2bBt-WAwyF_" 22 | 23 | BOOK_DATA = "data/goodreads_books_young_adult.json.gz" 24 | INTERACTIONS_DATA = "data/goodreads_interactions_young_adult.json.gz" 25 | REVIEWS_DATA = "data/goodreads_reviews_young_adult.json.gz" 26 | SAMPLE_DATA = "data/sample_data.pkl" 27 | 28 | 29 | # + 30 | def save_small_sample(): 31 | """Load full data, sample, and dump to file..""" 32 | (df_train, df_test, df_dev, df_valid), df_books = download_and_process_data() 33 | df_train = df_train.dropna().sample(frac=0.01) 34 | df_test = df_test.dropna().sample(frac=0.01) 35 | df_dev = df_dev.dropna().sample(frac=0.01) 36 | df_valid = df_valid.dropna().sample(frac=0.01) 37 | df_all = pd.concat([df_train, df_test, df_dev, df_valid], axis=0) 38 | df_books = df_books.merge( 39 | df_all[["book_idx"]].drop_duplicates(), on="book_idx", how="inner" 40 | ) 41 | with open(SAMPLE_DATA, "wb") as f: 42 | pickle.dump(df_train, f) 43 | pickle.dump(df_test, f) 44 | pickle.dump(df_dev, f) 45 | pickle.dump(df_valid, f) 46 | pickle.dump(df_books, f) 47 | 48 | 49 | def load_small_sample(): 50 | """Load sample data.""" 51 | with open(SAMPLE_DATA, "rb") as f: 52 | df_train = pickle.load(f) 53 | df_test = pickle.load(f) 54 | df_dev = pickle.load(f) 55 | df_valid = pickle.load(f) 56 | df_books = pickle.load(f) 57 | return (df_train, df_test, df_dev, df_valid), df_books 58 | 59 | 60 | # - 61 | 62 | 63 | def maybe_download_files(data_dir: str = "data") -> None: 64 | if not os.path.exists(data_dir): 65 | os.makedirs(data_dir, exist_ok=True) 66 | if IS_TEST: 67 | # Sample data pickle 68 | gdown.download(SMALL_DATA_URL, output=SAMPLE_DATA, quiet=None) 69 | else: 70 | # Books 71 | gdown.download(YA_BOOKS_URL, output=BOOK_DATA, quiet=None) 72 | # Interactions 73 | gdown.download(YA_INTERACTIONS_URL, output=INTERACTIONS_DATA, quiet=None) 74 | # Reviews 75 | gdown.download(YA_REVIEWS_URL, output=REVIEWS_DATA, quiet=None) 76 | 77 | 78 | def get_timestamp(date_str: str) -> datetime.timestamp: 79 | month_to_int = dict((v, k) for k, v in enumerate(calendar.month_abbr)) 80 | _, month, day, _, _, year = date_str.split() 81 | dt = datetime(year=int(year), month=month_to_int[month], day=int(day)) 82 | return datetime.timestamp(dt) 83 | 84 | 85 | def load_data( 86 | file_name: str, max_to_load: int = 100, filter_dict: Optional[dict] = None 87 | ) -> List[Dict[str, Any]]: 88 | count = 0 89 | data = [] 90 | filter_dict = filter_dict or {} 91 | with gzip.open(file_name) as fin: 92 | for l in fin: 93 | d = json.loads(l) 94 | for k, v in filter_dict.items(): 95 | if d[k] not in v: 96 
| break 97 | else: 98 | count += 1 99 | data.append(d) 100 | if (max_to_load is not None) and (count >= max_to_load): 101 | break 102 | return data 103 | 104 | 105 | def process_books_data( 106 | book_path: str = BOOK_DATA, min_ratings: int = 100, max_ratings: int = 15000 107 | ) -> Tuple[pd.DataFrame, Dict[int, int]]: 108 | books = load_data(book_path, None) 109 | df_books = pd.DataFrame(books) 110 | df_books = df_books[ 111 | [ 112 | "authors", 113 | "average_rating", 114 | "book_id", 115 | "country_code", 116 | "description", 117 | "is_ebook", 118 | "language_code", 119 | "ratings_count", 120 | "similar_books", 121 | "text_reviews_count", 122 | "title", 123 | ] 124 | ] 125 | df_books = df_books.astype( 126 | dict( 127 | average_rating=float, 128 | book_id=int, 129 | is_ebook=bool, 130 | ratings_count=int, 131 | text_reviews_count=int, 132 | ) 133 | ) 134 | # Turns author role dict into list of <= 5 authors for simplicity. 135 | df_books.authors = df_books.authors.map( 136 | lambda l: [pair["author_id"] for pair in l[:5]] 137 | ) 138 | df_books["first_author"] = df_books.authors.map(lambda l: int(l[0])) 139 | 140 | df_books = df_books[ 141 | (df_books.ratings_count >= min_ratings) 142 | & (df_books.ratings_count <= max_ratings) 143 | ] 144 | 145 | book_id_to_idx = {v: i for i, v in enumerate(df_books.book_id)} 146 | df_books["book_idx"] = df_books.book_id.map(book_id_to_idx) 147 | return df_books, book_id_to_idx 148 | 149 | 150 | def process_interactions_data( 151 | book_id_to_idx: Dict[int, int], 152 | interactions_path: str = INTERACTIONS_DATA, 153 | min_user_count: int = 25, 154 | max_user_count: int = 200, 155 | max_to_load: int = 5_000_000, 156 | ) -> Tuple[pd.DataFrame, Dict[int, int]]: 157 | interactions = load_data( 158 | interactions_path, 159 | max_to_load, 160 | dict(book_id=set(map(str, book_id_to_idx.keys()))), 161 | ) 162 | df_interactions = pd.DataFrame(interactions) 163 | df_interactions = df_interactions[ 164 | ["book_id", "is_read", "rating", "review_id", "user_id"] 165 | ] 166 | df_interactions = df_interactions.astype( 167 | dict(book_id=int, is_read=bool, rating=int) 168 | ) 169 | df_interactions["book_idx"] = df_interactions.book_id.map(book_id_to_idx) 170 | user_counts = df_interactions.groupby(["user_id"]).size() 171 | user_mask = (user_counts >= min_user_count) & (user_counts <= max_user_count) 172 | users_filt = user_counts[user_mask].index 173 | user_id_to_idx = {v: i for i, v in enumerate(users_filt)} 174 | df_interactions = df_interactions[ 175 | df_interactions.user_id.isin(set(user_id_to_idx.keys())) 176 | ] 177 | df_interactions["user_idx"] = df_interactions.user_id.map(user_id_to_idx) 178 | return df_interactions, user_id_to_idx 179 | 180 | 181 | def process_reviews_data( 182 | book_id_to_idx: Dict[int, int], 183 | user_id_to_idx: Dict[int, int], 184 | reviews_path: str = REVIEWS_DATA, 185 | ) -> pd.DataFrame: 186 | reviews = load_data( 187 | reviews_path, 188 | None, 189 | dict( 190 | book_id=set(map(str, book_id_to_idx.keys())), 191 | user_id=set(user_id_to_idx.keys()), 192 | ), 193 | ) 194 | df_reviews = pd.DataFrame(reviews) 195 | df_reviews["book_idx"] = df_reviews.book_id.astype("int").map(book_id_to_idx) 196 | df_reviews["user_idx"] = df_reviews.user_id.map(user_id_to_idx) 197 | return df_reviews 198 | 199 | 200 | def split_data(user_idxs, data: pd.DataFrame) -> Tuple[pd.DataFrame, ...]: 201 | user_idxs_train, user_idxs_test = train_test_split(user_idxs, test_size=0.05) 202 | user_idxs_train, user_idxs_dev = train_test_split(user_idxs_train, 
test_size=0.01) 203 | user_idxs_train, user_idxs_val = train_test_split(user_idxs_train, test_size=0.01) 204 | 205 | data_train = data[data.user_idx.isin(set(user_idxs_train))].drop("rating", axis=1) 206 | data_test = data[data.user_idx.isin(set(user_idxs_test))] 207 | data_dev = data[data.user_idx.isin(set(user_idxs_dev))] 208 | data_val = data[data.user_idx.isin(set(user_idxs_val))] 209 | return data_train, data_test, data_dev, data_val 210 | 211 | 212 | def download_and_process_data() -> Tuple[Tuple[pd.DataFrame, ...], pd.DataFrame]: 213 | logging.info("Downloading raw data") 214 | maybe_download_files() 215 | if IS_TEST: 216 | return load_small_sample() 217 | logging.info("Processing book data") 218 | df_books, book_id_to_idx = process_books_data() 219 | logging.info("Processing interaction data") 220 | df_interactions, user_id_to_idx = process_interactions_data(book_id_to_idx) 221 | df_interactions_nz = df_interactions[df_interactions.rating != 0] 222 | ratings_map = {1: 0, 2: 0, 3: 0, 4: 1, 5: 1} 223 | df_interactions_nz["rating_4_5"] = df_interactions_nz.rating.map(ratings_map) 224 | logging.info("Processing review data") 225 | df_reviews = process_reviews_data(book_id_to_idx, user_id_to_idx) 226 | logging.info("Joining interaction data") 227 | # Compute book_idxs for each user. 228 | user_to_books = ( 229 | df_interactions.groupby("user_idx")["book_idx"] 230 | .apply(tuple) 231 | .reset_index() 232 | .rename(columns={"book_idx": "book_idxs"}) 233 | ) 234 | data = user_to_books.merge(df_interactions_nz, on="user_idx", how="inner")[ 235 | ["user_idx", "book_idxs", "book_idx", "rating_4_5"] 236 | ].merge( 237 | df_reviews[["user_idx", "book_idx", "review_text"]], 238 | on=["user_idx", "book_idx"], 239 | how="left", 240 | ) 241 | data = data.rename(columns={"rating_4_5": "rating"}) 242 | user_idxs = list(user_id_to_idx.values()) 243 | return split_data(user_idxs, data), df_books 244 | 245 | 246 | def recall_batch(y_true: np.ndarray, y_pred: np.ndarray) -> float: 247 | true_positives = K.sum(K.round(y_true * y_pred)) 248 | all_positives = K.sum(y_true) 249 | return true_positives / (all_positives + K.epsilon()) 250 | 251 | 252 | def precision_batch(y_true: np.ndarray, y_pred: np.ndarray) -> float: 253 | true_positives = K.sum(K.round(y_true * y_pred)) 254 | predicted_positives = K.sum(K.round(y_pred)) 255 | return true_positives / (predicted_positives + K.epsilon()) 256 | 257 | 258 | def f1_batch(y_true: np.ndarray, y_pred: np.ndarray) -> float: 259 | prec = precision_batch(y_true, y_pred) 260 | rec = recall_batch(y_true, y_pred) 261 | return 2 * ((prec * rec) / (prec + rec + K.epsilon())) 262 | 263 | 264 | def get_n_epochs() -> int: 265 | return 2 if IS_TEST else 30 266 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /crowdsourcing/crowdsourcing_tutorial.py: -------------------------------------------------------------------------------- 1 | # %% [markdown] 2 | # # Crowdsourcing Tutorial 3 | 4 | # %% [markdown] 5 | # In this tutorial, we'll provide a simple walkthrough of how to use Snorkel in conjunction with crowdsourcing to create a training set for a sentiment analysis task. 6 | # We already have crowdsourced labels for about half of the training dataset. 7 | # The crowdsourced labels are fairly accurate, but do not cover the entire training dataset, nor are they available for the test set or during inference. 8 | # To make up for their lack of training set coverage, we combine crowdsourced labels with heuristic labeling functions to increase the number of training labels we have. 9 | # Like most Snorkel labeling pipelines, we'll use the denoised labels to train a deep learning 10 | # model which can be applied to new, unseen data to automatically make predictions. 11 | 12 | # %% [markdown] 13 | # ## Dataset Details 14 | 15 | # %% [markdown] 16 | # In this tutorial, we'll use the [Weather Sentiment](https://data.world/crowdflower/weather-sentiment) dataset from Figure Eight. 17 | # Our goal is to train a classifier that can label new tweets as expressing either a positive or negative sentiment. 18 | # 19 | # Crowdworkers were asked to label the sentiment of a particular tweet relating to the weather. 20 | # The catch is that 20 crowdworkers graded each tweet, and in many cases crowdworkers assigned conflicting sentiment labels to the same tweet. 21 | # This is a common issue when dealing with crowdsourced labeling workloads. 22 | # 23 | # Label options were positive, negative, or one of three other options saying they weren't sure if it was positive or negative; we use only the positive/negative labels. 24 | # We've also altered the dataset to reflect a realistic crowdsourcing pipeline where only a subset of our available training set has received crowd labels. 25 | # 26 | # We will treat each crowdworker's labels as coming from a single labeling function (LF). 27 | # This will allow us to learn a weight for how much to trust the labels from each crowdworker. 28 | # We will also write a few heuristic labeling functions to cover the data points without crowd labels. 29 | # Snorkel's ability to build high-quality datasets from multiple noisy labeling signals makes it an ideal framework to approach this problem. 30 | 31 | # %% [markdown] 32 | # ## Loading Crowdsourcing Dataset 33 | 34 | # %% [markdown] 35 | # We start by loading our data which has 287 data points in total. 36 | # We take 50 for our development set and 50 for our test set. 37 | # The remaining 187 data points form our training set. 38 | # Since the dataset is already small, we skip using a validation set. 39 | # Note that this very small dataset is primarily used for demonstration purposes here. 
40 | # In a real setting, we would expect to have access to many more unlabeled tweets, which could help us to train a higher quality model. 41 | 42 | # %% {"tags": ["md-exclude"]} 43 | import os 44 | 45 | # Make sure we're in the right directory 46 | if os.path.basename(os.getcwd()) == "snorkel-tutorials": 47 | os.chdir("crowdsourcing") 48 | 49 | # %% 50 | from data import load_data 51 | 52 | crowd_labels, df_train, df_dev, df_test = load_data() 53 | Y_dev = df_dev.sentiment.values 54 | Y_test = df_test.sentiment.values 55 | 56 | # %% [markdown] {"tags": ["md-exclude"]} 57 | # First, let's take a look at our development set to get a sense of what the tweets look like. 58 | # We use the following label convention: 0 = Negative, 1 = Positive. 59 | 60 | # %% {"tags": ["md-exclude"]} 61 | import pandas as pd 62 | 63 | # Don't truncate text fields in the display 64 | pd.set_option("display.max_colwidth", 0) 65 | 66 | df_dev.head() 67 | 68 | # %% [markdown] {"tags": ["md-exclude"]} 69 | # Now let's take a look at the crowd labels. 70 | # We'll convert these into labeling functions. 71 | 72 | # %% {"tags": ["md-exclude"]} 73 | crowd_labels.head() 74 | 75 | # %% [markdown] 76 | # ## Writing Labeling Functions 77 | # Each crowdworker can be thought of as a single labeling function, 78 | # as each worker labels a subset of data points, 79 | # and may have errors or conflicting labels with other workers / labeling functions. 80 | # So we create one labeling function per worker. 81 | # We'll simply return the label the worker submitted for a given tweet, and abstain 82 | # if they didn't submit a label for it. 83 | 84 | # %% [markdown] 85 | # ### Crowdworker labeling functions 86 | 87 | # %% 88 | labels_by_annotator = crowd_labels.groupby("worker_id") 89 | worker_dicts = {} 90 | for worker_id in labels_by_annotator.groups: 91 | worker_df = labels_by_annotator.get_group(worker_id)[["label"]] 92 | worker_dicts[worker_id] = dict(zip(worker_df.index, worker_df.label)) 93 | 94 | print("Number of workers:", len(worker_dicts)) 95 | 96 | # %% 97 | from snorkel.labeling import LabelingFunction 98 | 99 | ABSTAIN = -1 100 | 101 | 102 | def worker_lf(x, worker_dict): 103 | return worker_dict.get(x.tweet_id, ABSTAIN) 104 | 105 | 106 | def make_worker_lf(worker_id): 107 | worker_dict = worker_dicts[worker_id] 108 | name = f"worker_{worker_id}" 109 | return LabelingFunction(name, f=worker_lf, resources={"worker_dict": worker_dict}) 110 | 111 | 112 | worker_lfs = [make_worker_lf(worker_id) for worker_id in worker_dicts] 113 | 114 | # %% [markdown] 115 | # Let's take a quick look at how well they do on the development set. 116 | 117 | # %% {"tags": ["md-exclude-output"]} 118 | from snorkel.labeling import PandasLFApplier 119 | 120 | applier = PandasLFApplier(worker_lfs) 121 | L_train = applier.apply(df_train) 122 | L_dev = applier.apply(df_dev) 123 | 124 | # %% [markdown] 125 | # Note that because our dev set is so small and our LFs are relatively sparse, many LFs will appear to have zero coverage. 126 | # Fortunately, our label model learns weights for LFs based on their outputs on the training set, which is generally much larger. 127 | 128 | # %% 129 | from snorkel.labeling import LFAnalysis 130 | 131 | LFAnalysis(L_dev, worker_lfs).lf_summary(Y_dev).sample(5) 132 | 133 | # %% [markdown] 134 | # So the crowd labels in general are quite good! But how much of our dev and training 135 | # sets do they cover? 
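# %% [markdown]
# (A minimal illustration on a toy matrix, not our real data: "coverage" here means the fraction of data points
# that received at least one non-abstain label from some LF.)

# %%
import numpy as np

# Toy label matrix: 4 data points x 2 LFs, where -1 means the LF abstained.
L_toy = np.array([[1, -1], [-1, -1], [0, 1], [-1, 0]])
# A data point is covered if any LF voted on it; 3 of the 4 toy points are covered.
print(f"Toy coverage: {(L_toy != -1).any(axis=1).mean():.2f}")  # 0.75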
136 | 137 | # %% 138 | print(f"Training set coverage: {100 * LFAnalysis(L_train).label_coverage(): 0.1f}%") 139 | print(f"Dev set coverage: {100 * LFAnalysis(L_dev).label_coverage(): 0.1f}%") 140 | 141 | # %% [markdown] 142 | # ### Additional labeling functions 143 | # 144 | # To improve coverage of the training set, we can mix the crowdworker labeling functions with labeling 145 | # functions of other types. 146 | # For example, we can use [TextBlob](https://textblob.readthedocs.io/en/dev/index.html), a tool that provides a pretrained sentiment analyzer. We run TextBlob on our tweets and create some simple LFs that threshold its polarity score, similar to what we did in the spam_tutorial. 147 | 148 | # %% 149 | from snorkel.labeling import labeling_function 150 | from snorkel.preprocess import preprocessor 151 | from textblob import TextBlob 152 | 153 | 154 | @preprocessor(memoize=True) 155 | def textblob_polarity(x): 156 | scores = TextBlob(x.tweet_text) 157 | x.polarity = scores.polarity 158 | return x 159 | 160 | 161 | # Label high polarity tweets as positive. 162 | @labeling_function(pre=[textblob_polarity]) 163 | def polarity_positive(x): 164 | return 1 if x.polarity > 0.3 else -1 165 | 166 | 167 | # Label low polarity tweets as negative. 168 | @labeling_function(pre=[textblob_polarity]) 169 | def polarity_negative(x): 170 | return 0 if x.polarity < -0.25 else -1 171 | 172 | 173 | # Similar to polarity_negative, but with higher coverage and lower precision. 174 | @labeling_function(pre=[textblob_polarity]) 175 | def polarity_negative_2(x): 176 | return 0 if x.polarity <= 0.3 else -1 177 | 178 | 179 | # %% [markdown] 180 | # ### Applying labeling functions to the training set 181 | 182 | # %% {"tags": ["md-exclude-output"]} 183 | text_lfs = [polarity_positive, polarity_negative, polarity_negative_2] 184 | lfs = text_lfs + worker_lfs 185 | 186 | applier = PandasLFApplier(lfs) 187 | L_train = applier.apply(df_train) 188 | L_dev = applier.apply(df_dev) 189 | 190 | # %% 191 | LFAnalysis(L_dev, lfs).lf_summary(Y_dev).head() 192 | 193 | # %% [markdown] 194 | # Using the text-based LFs, we've expanded coverage on both our training set 195 | # and dev set to 100%. 196 | # We'll now take these noisy and conflicting labels, and use the LabelModel 197 | # to denoise and combine them. 198 | 199 | # %% 200 | print(f"Training set coverage: {100 * LFAnalysis(L_train).label_coverage(): 0.1f}%") 201 | print(f"Dev set coverage: {100 * LFAnalysis(L_dev).label_coverage(): 0.1f}%") 202 | 203 | # %% [markdown] 204 | # ## Train LabelModel And Generate Probabilistic Labels 205 | 206 | # %% {"tags": ["md-exclude-output"]} 207 | from snorkel.labeling.model import LabelModel 208 | 209 | # Train LabelModel. 210 | label_model = LabelModel(cardinality=2, verbose=True) 211 | label_model.fit(L_train, n_epochs=100, seed=123, log_freq=20, l2=0.1, lr=0.01) 212 | 213 | # %% [markdown] 214 | # As a spot-check for the quality of our LabelModel, we'll score it on the dev set. 215 | 216 | # %% 217 | from snorkel.analysis import metric_score 218 | 219 | preds_dev = label_model.predict(L_dev) 220 | 221 | acc = metric_score(Y_dev, preds_dev, probs=None, metric="accuracy") 222 | print(f"LabelModel Accuracy: {acc:.3f}") 223 | 224 | # %% [markdown] 225 | # We see that we get very high accuracy on the development set. 226 | # This is due to the abundance of high quality crowdworker labels. 
227 | # **Since we don't have these high quality crowdsourcing labels for the 228 | # test set or new incoming data points, we can't use the LabelModel reliably 229 | # at inference time.** 230 | # In order to run inference on new incoming data points, we need to train a 231 | # discriminative model over the tweets themselves. 232 | # Let's generate a set of labels for that training set. 233 | 234 | # %% 235 | preds_train = label_model.predict(L_train) 236 | 237 | # %% [markdown] 238 | # ## Use Soft Labels to Train End Model 239 | 240 | # %% [markdown] 241 | # ### Getting features from BERT 242 | # Since we have very limited training data, we cannot train a complex model like an LSTM with a lot of parameters. 243 | # Instead, we use a pre-trained model, [BERT](https://github.com/google-research/bert), to generate embeddings for each our tweets, and treat the embedding values as features. 244 | # This may take 5-10 minutes on a CPU, as the BERT model is very large. 245 | 246 | # %% {"tags": ["md-exclude-output"]} 247 | import numpy as np 248 | import torch 249 | from pytorch_transformers import BertModel, BertTokenizer 250 | 251 | model = BertModel.from_pretrained("bert-base-uncased") 252 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 253 | 254 | 255 | def encode_text(text): 256 | input_ids = torch.tensor([tokenizer.encode(text)]) 257 | return model(input_ids)[0].mean(1)[0].detach().numpy() 258 | 259 | 260 | X_train = np.array(list(df_train.tweet_text.apply(encode_text).values)) 261 | X_test = np.array(list(df_test.tweet_text.apply(encode_text).values)) 262 | 263 | # %% [markdown] 264 | # ### Model on labels 265 | # Now, we train a simple logistic regression model on the BERT features, using labels 266 | # obtained from our LabelModel. 267 | 268 | # %% {"tags": ["md-exclude-output"]} 269 | from sklearn.linear_model import LogisticRegression 270 | 271 | sklearn_model = LogisticRegression(solver="liblinear") 272 | sklearn_model.fit(X_train, preds_train) 273 | 274 | # %% 275 | print(f"Accuracy of trained model: {sklearn_model.score(X_test, Y_test)}") 276 | 277 | # %% [markdown] 278 | # We now have a trained model that can be applied to future data points without requiring crowdsourced labels, and with accuracy not much lower than the `LabelModel` that _does_ have access to crowdsourced labels! 279 | 280 | # %% [markdown] 281 | # ## Summary 282 | # 283 | # In this tutorial, we accomplished the following: 284 | # * We demonstrated how to combine crowdsourced labels with other programmatic LFs to improve coverage. 285 | # * We used the `LabelModel` to combine inputs from crowdworkers and other LFs to generate high quality probabilistic labels. 286 | # * We used our labels to train a classifier for making predictions on new, unseen data points. 
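# %% [markdown]
# As a final, purely illustrative sketch (the tweet below is invented, not drawn from the dataset), making a
# prediction on a new tweet only requires the BERT encoder and the fitted logistic regression model from above:

# %%
new_tweet = "Gorgeous sunshine all afternoon, loving this weather!"  # hypothetical example tweet
new_features = encode_text(new_tweet).reshape(1, -1)  # single-row feature matrix of BERT embeddings
print(f"Predicted sentiment (1 = Positive, 0 = Negative): {sklearn_model.predict(new_features)[0]}")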
287 | -------------------------------------------------------------------------------- /scripts/build.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | import subprocess 5 | import tempfile 6 | import urllib 7 | import yaml 8 | from typing import List, Optional 9 | 10 | import click 11 | import jupytext 12 | from jupytext.compare import compare_notebooks 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | 17 | NOTEBOOKS_CONFIG_FNAME = ".notebooks" 18 | SCRIPTS_CONFIG_FNAME = ".scripts" 19 | EXCLUDE_CELL_TAG = "md-exclude" 20 | EXCLUDE_OUTPUT_TAG = "md-exclude-output" 21 | BUILD_DIR = "build" 22 | WEB_YML = ".web.yml" 23 | 24 | 25 | HEADER_TEMPLATE = """--- 26 | layout: default 27 | title: {title} 28 | description: {description} 29 | excerpt: {description} 30 | order: {order} 31 | github_link: {github_link} 32 | --- 33 | 34 | """ 35 | 36 | 37 | GITHUB_LINK_TEMPLATE = ( 38 | "https://github.com/snorkel-team/snorkel-tutorials/blob/master/{notebook_path}" 39 | ) 40 | 41 | 42 | # Credit to: https://gist.github.com/pchc2005/b5f13e136a9c9bb2984e5b92802fc7c9 43 | # Original source: https://gist.github.com/dperini/729294 44 | MARKDOWN_URL_REGEX = re.compile( 45 | "\(" 46 | # protocol identifier 47 | "(?:(?:(?:https?|ftp):)?//)" 48 | # user:pass authentication 49 | "(?:\S+(?::\S*)?@)?" "(?:" 50 | # IP address exclusion 51 | # private & local networks 52 | "(?!(?:10|127)(?:\.\d{1,3}){3})" 53 | "(?!(?:169\.254|192\.168)(?:\.\d{1,3}){2})" 54 | "(?!172\.(?:1[6-9]|2\d|3[0-1])(?:\.\d{1,3}){2})" 55 | # IP address dotted notation octets 56 | # excludes loopback network 0.0.0.0 57 | # excludes reserved space >= 224.0.0.0 58 | # excludes network & broadcast addresses 59 | # (first & last IP address of each class) 60 | "(?:[1-9]\d?|1\d\d|2[01]\d|22[0-3])" 61 | "(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5])){2}" 62 | "(?:\.(?:[1-9]\d?|1\d\d|2[0-4]\d|25[0-4]))" 63 | "|" 64 | # host & domain names, may end with dot 65 | # can be replaced by a shortest alternative 66 | # u"(?![-_])(?:[-\w\u00a1-\uffff]{0,63}[^-_]\.)+" 67 | # u"(?:(?:[a-z\u00a1-\uffff0-9]-?)*[a-z\u00a1-\uffff0-9]+)" 68 | # # domain name 69 | # u"(?:\.(?:[a-z\u00a1-\uffff0-9]-?)*[a-z\u00a1-\uffff0-9]+)*" 70 | "(?:" 71 | "(?:" 72 | "[a-z0-9\u00a1-\uffff]" 73 | "[a-z0-9\u00a1-\uffff_-]{0,62}" 74 | ")?" 75 | "[a-z0-9\u00a1-\uffff]\." 76 | ")+" 77 | # TLD identifier name, may end with dot 78 | "(?:[a-z\u00a1-\uffff]{2,}\.?)" ")" 79 | # port number (optional) 80 | "(?::\d{2,5})?" 81 | # resource path (optional) 82 | "(?:[/?#]\S*)?" 
"\)", 83 | re.UNICODE | re.I, 84 | ) 85 | 86 | 87 | class Notebook: 88 | def __init__(self, notebook_path: str) -> None: 89 | self.basename = os.path.splitext(notebook_path)[0] 90 | 91 | @property 92 | def py(self) -> str: 93 | return f"{self.basename}.py" 94 | 95 | @property 96 | def ipynb(self) -> str: 97 | return f"{self.basename}.ipynb" 98 | 99 | 100 | class MarkdownHeader: 101 | def __init__( 102 | self, title: str, description: str, order: int, github_link: str 103 | ) -> None: 104 | self.title = title 105 | self.description = description 106 | self.order = order 107 | self.github_link = github_link 108 | 109 | def render(self): 110 | return HEADER_TEMPLATE.format( 111 | title=self.title, 112 | description=self.description, 113 | order=self.order, 114 | github_link=self.github_link, 115 | ) 116 | 117 | 118 | class TutorialWebpage: 119 | def __init__( 120 | self, 121 | ipynb_path: str, 122 | header: Optional[MarkdownHeader], 123 | exclude_all_output: bool, 124 | ) -> None: 125 | self.ipynb = ipynb_path 126 | self.header = header 127 | self.exclude_all_output = exclude_all_output 128 | 129 | def markdown_path(self) -> str: 130 | return os.path.join( 131 | BUILD_DIR, f"{os.path.splitext(os.path.basename(self.ipynb))[0]}.md" 132 | ) 133 | 134 | 135 | def parse_web_yml(tutorial_dir: Optional[str]) -> List[TutorialWebpage]: 136 | # Read .web.yml 137 | with open(WEB_YML, "r") as f: 138 | web_config = yaml.safe_load(f) 139 | tutorial_webpages = [] 140 | # Process webpage configs in order 141 | i = 1 142 | for cfg in web_config["tutorials"]: 143 | # If tutorial directory specified, skip if not in specified directory 144 | notebook_path = cfg["notebook"] 145 | notebook_dir = notebook_path.split("/")[0] 146 | if tutorial_dir is not None and notebook_dir != tutorial_dir: 147 | continue 148 | # If full notebook path supplied, just use that 149 | if notebook_path.endswith(".ipynb"): 150 | notebook = Notebook(os.path.abspath(notebook_path)) 151 | # If only directory supply, ensure that there's only one notebook 152 | else: 153 | notebooks = get_notebooks(notebook_path) 154 | if len(notebooks) > 1: 155 | raise ValueError(f"Multiple notebooks found in {notebook_path}") 156 | notebook = notebooks[0] 157 | # If no title or description, don't generate order for header 158 | title = cfg.get("title") 159 | description = cfg.get("description") 160 | if title is not None and description is not None: 161 | full_notebook_path = notebook.ipynb.split("/snorkel-tutorials/")[-1] 162 | github_link = GITHUB_LINK_TEMPLATE.format(notebook_path=full_notebook_path) 163 | header = MarkdownHeader(title, description, i, github_link) 164 | i += 1 165 | else: 166 | header = None 167 | # Create TutorialWebpage object 168 | tutorial_webpages.append( 169 | TutorialWebpage( 170 | ipynb_path=notebook.ipynb, 171 | header=header, 172 | exclude_all_output=cfg.get("exclude_all_output", False), 173 | ) 174 | ) 175 | return tutorial_webpages 176 | 177 | 178 | def check_links(script_path: str) -> None: 179 | with open(script_path, "r") as f: 180 | contents = f.read() 181 | link_matches = list(MARKDOWN_URL_REGEX.finditer(contents)) 182 | for link_match in link_matches: 183 | url = link_match.group(0).rstrip(")").lstrip("(") 184 | req = urllib.request.Request(url, headers={"User-Agent": "Magic Browser"}) 185 | logging.info(f"Checking link [{url}]") 186 | try: 187 | urllib.request.urlopen(req, timeout=5) 188 | except urllib.error.HTTPError as e: 189 | raise ValueError(f"Bad link [{url}] found in {script_path}: {e}") 190 | except Exception 
as e: 191 | logging.warning( 192 | f"SKIPPING: Could not access [{url}] found in {script_path}: {e}" 193 | ) 194 | 195 | 196 | def call_jupytext(notebook: Notebook, out_fname: str, to_ipynb: bool) -> None: 197 | to_fmt = "ipynb" if to_ipynb else "py:percent" 198 | from_fmt = "py:percent" if to_ipynb else "ipynb" 199 | args = [ 200 | "jupytext", 201 | "--to", 202 | to_fmt, 203 | "--from", 204 | from_fmt, 205 | "--opt", 206 | "notebook_metadata_filter=-all", 207 | "--opt", 208 | "cell_metadata_filter=tags", 209 | notebook.py if to_ipynb else notebook.ipynb, 210 | "-o", 211 | out_fname, 212 | ] 213 | if to_ipynb: 214 | args.append("--execute") 215 | subprocess.run(args, check=True) 216 | 217 | 218 | def get_notebooks(tutorial_dir: str) -> List[Notebook]: 219 | path = os.path.abspath(tutorial_dir) 220 | config_path = os.path.join(path, NOTEBOOKS_CONFIG_FNAME) 221 | if not os.path.isfile(config_path): 222 | logging.info(f"No {NOTEBOOKS_CONFIG_FNAME} config file in {path}") 223 | return [] 224 | with open(config_path, "r") as f: 225 | notebooks = f.read().splitlines() 226 | return [Notebook(os.path.join(path, nb)) for nb in notebooks if nb] 227 | 228 | 229 | def get_scripts(tutorial_dir: str) -> List[Notebook]: 230 | path = os.path.abspath(tutorial_dir) 231 | config_path = os.path.join(path, SCRIPTS_CONFIG_FNAME) 232 | if not os.path.isfile(config_path): 233 | logging.info(f"No {SCRIPTS_CONFIG_FNAME} config file in {path}") 234 | return [] 235 | with open(config_path, "r") as f: 236 | scripts = [os.path.join(path, s) for s in f.read().splitlines() if s] 237 | return scripts 238 | 239 | 240 | def check_notebook(notebook: Notebook) -> None: 241 | assert os.path.exists(notebook.py), f"No file {notebook.py}" 242 | os.environ["IS_TEST"] = "true" 243 | logging.info(f"Checking links in [{notebook.py}]") 244 | check_links(notebook.py) 245 | notebook_actual = jupytext.read(notebook.ipynb, fmt=dict(extension="ipynb")) 246 | with tempfile.NamedTemporaryFile(suffix=".ipynb") as f: 247 | logging.info(f"Executing notebook [{notebook.py}]") 248 | call_jupytext(notebook, f.name, to_ipynb=True) 249 | notebook_expected = jupytext.read(f.name, fmt=dict(extension="ipynb")) 250 | # notebook_metadata_filter gets flipped during execution. Remove it to ensure 251 | # all metadata is tested. 
252 | notebook_actual.metadata.get("jupytext", {}).pop( 253 | "notebook_metadata_filter", None 254 | ) 255 | notebook_expected.metadata.get("jupytext", {}).pop( 256 | "notebook_metadata_filter", None 257 | ) 258 | compare_notebooks(notebook_actual, notebook_expected) 259 | 260 | 261 | def check_script(script_path: str) -> None: 262 | assert os.path.exists(script_path), f"No file {script_path}" 263 | logging.info(f"Checking links in [{script_path}]") 264 | check_links(script_path) 265 | logging.info(f"Executing script [{script_path}]") 266 | check_run = subprocess.run(["python", script_path]) 267 | if check_run.returncode: 268 | raise ValueError(f"Error running {script_path}") 269 | 270 | 271 | def build_markdown_notebook(tutorial: TutorialWebpage) -> None: 272 | assert os.path.exists(tutorial.ipynb), f"No file {tutorial.ipynb}" 273 | os.makedirs(BUILD_DIR, exist_ok=True) 274 | # Call nbconvert 275 | args = [ 276 | "jupyter", 277 | "nbconvert", 278 | tutorial.ipynb, 279 | "--to", 280 | "markdown", 281 | f"--TagRemovePreprocessor.remove_cell_tags={{'{EXCLUDE_CELL_TAG}'}}", 282 | f"--TagRemovePreprocessor.remove_all_outputs_tags={{'{EXCLUDE_OUTPUT_TAG}'}}", 283 | "--output-dir", 284 | BUILD_DIR, 285 | ] 286 | if tutorial.exclude_all_output: 287 | args.append("--TemplateExporter.exclude_output=True") 288 | subprocess.run(args, check=True) 289 | # Prepend header by reading generated file then writing back 290 | if tutorial.header is not None: 291 | with open(tutorial.markdown_path(), "r") as f: 292 | content = f.read() 293 | with open(tutorial.markdown_path(), "w") as f: 294 | f.write(tutorial.header.render() + content) 295 | 296 | 297 | def sync_notebook(notebook: Notebook) -> None: 298 | assert os.path.exists(notebook.py), f"No file {notebook.py}" 299 | call_jupytext(notebook, notebook.ipynb, to_ipynb=True) 300 | 301 | 302 | def sync_py(notebook: Notebook) -> None: 303 | assert os.path.exists(notebook.ipynb), f"No file {notebook.ipynb}" 304 | call_jupytext(notebook, notebook.py, to_ipynb=False) 305 | 306 | 307 | @click.group() 308 | def cli() -> None: 309 | pass 310 | 311 | 312 | @cli.command() 313 | @click.argument("tutorial_dir") 314 | def test(tutorial_dir: str) -> None: 315 | for notebook in get_notebooks(tutorial_dir): 316 | check_notebook(notebook) 317 | for script in get_scripts(tutorial_dir): 318 | check_script(script) 319 | 320 | 321 | @cli.command() 322 | @click.option("--tutorial-dir") 323 | def markdown(tutorial_dir: Optional[str]) -> None: 324 | for tutorial_webpage in parse_web_yml(tutorial_dir): 325 | build_markdown_notebook(tutorial_webpage) 326 | 327 | 328 | @cli.command() 329 | @click.argument("tutorial_dir") 330 | @click.option("--py", is_flag=True) 331 | def sync(tutorial_dir: str, py: bool) -> None: 332 | for notebook in get_notebooks(tutorial_dir): 333 | if py: 334 | sync_py(notebook) 335 | else: 336 | sync_notebook(notebook) 337 | 338 | 339 | if __name__ == "__main__": 340 | cli() 341 | -------------------------------------------------------------------------------- /spouse/spouse_demo.py: -------------------------------------------------------------------------------- 1 | # %% [markdown] 2 | # # Detecting spouse mentions in sentences 3 | 4 | # %% [markdown] 5 | # In this tutorial, we will see how Snorkel can be used for Information Extraction. We will walk through an example text classification task for information extraction, where we use labeling functions involving keywords and distant supervision. 
6 | # ### Classification Task 7 | # 8 | # 9 | # We want to classify each __candidate__ or pair of people mentioned in a sentence, as being married at some point or not. 10 | # 11 | # In the above example, our candidate represents the possible relation `(Barack Obama, Michelle Obama)`. As readers, we know this mention is true due to external knowledge and the keyword `wedding` occurring later in the sentence. 12 | # We begin with some basic setup and data downloading. 13 | # 14 | # %% {"tags": ["md-exclude"]} 15 | # %matplotlib inline 16 | 17 | import os 18 | import pandas as pd 19 | import pickle 20 | 21 | if os.path.basename(os.getcwd()) == "snorkel-tutorials": 22 | os.chdir("spouse") 23 | 24 | # %% 25 | from utils import load_data 26 | 27 | ((df_dev, Y_dev), df_train, (df_test, Y_test)) = load_data() 28 | 29 | # %% [markdown] 30 | # **Input Data:** `df_dev`, `df_train`, and `df_test` are `Pandas DataFrame` objects, where each row represents a particular __candidate__. For our problem, a candidate consists of a sentence and two people mentioned in the sentence. The DataFrames contain the fields `sentence`, which refers to the sentence of the candidate, `tokens`, the tokenized form of the sentence, and `person1_word_idx` and `person2_word_idx`, which represent `[start, end]` indices in the tokens at which the first and second person's name appear, respectively. 31 | # 32 | # We also have certain **preprocessed fields**, which we discuss a few cells below. 33 | 34 | # %% {"tags": ["md-exclude"]} 35 | 36 | # Don't truncate text fields in the display 37 | pd.set_option("display.max_colwidth", 0) 38 | 39 | df_dev.head() 40 | 41 | # %% [markdown] 42 | # Let's look at a candidate in the development set: 43 | 44 | # %% 45 | from preprocessors import get_person_text 46 | 47 | candidate = df_dev.loc[2] 48 | person_names = get_person_text(candidate).person_names 49 | 50 | print("Sentence: ", candidate["sentence"]) 51 | print("Person 1: ", person_names[0]) 52 | print("Person 2: ", person_names[1]) 53 | 54 | # %% [markdown] 55 | # ### Preprocessing the Data 56 | # 57 | # In a real application, there is a lot of data preparation, parsing, and database loading that needs to be completed before we generate candidates and dive into writing labeling functions. Here we've pre-generated candidates in a pandas DataFrame object per split (train, dev, test). 58 | 59 | # %% [markdown] 60 | # ### Labeling Function Helpers 61 | # 62 | # When writing labeling functions, there are several functions you will use over and over again. In the case of text relation extraction as with this task, common functions include those for fetching text between mentions of the two people in a candidate, examining word windows around person mentions, and so on. We will wrap these functions as `preprocessors`. 63 | 64 | # %% 65 | from snorkel.preprocess import preprocessor 66 | 67 | 68 | @preprocessor() 69 | def get_text_between(cand): 70 | """ 71 | Returns the text between the two person mentions in the sentence for a candidate 72 | """ 73 | start = cand.person1_word_idx[1] + 1 74 | end = cand.person2_word_idx[0] 75 | cand.text_between = " ".join(cand.tokens[start:end]) 76 | return cand 77 | 78 | 79 | # %% [markdown] 80 | # ### Candidate PreProcessors 81 | # 82 | # For the purposes of the tutorial, we have three fields (`between_tokens`, `person1_right_tokens`, `person2_right_tokens`) preprocessed in the data, which can be used when creating labeling functions.
We also provide the following set of `preprocessor`s for this task in `preprocessors.py`, along with the fields these populate. 83 | # * `get_person_text(cand)`: `person_names` 84 | # * `get_person_lastnames(cand)`: `person_lastnames` 85 | # * `get_left_tokens(cand)`: `person1_left_tokens`, `person2_left_tokens` 86 | 87 | # %% 88 | from preprocessors import get_left_tokens, get_person_last_names 89 | 90 | POSITIVE = 1 91 | NEGATIVE = 0 92 | ABSTAIN = -1 93 | 94 | # %% 95 | from snorkel.labeling import labeling_function 96 | 97 | # Check for the `spouse` words appearing between the person mentions 98 | spouses = {"spouse", "wife", "husband", "ex-wife", "ex-husband"} 99 | 100 | 101 | @labeling_function(resources=dict(spouses=spouses)) 102 | def lf_husband_wife(x, spouses): 103 | return POSITIVE if len(spouses.intersection(set(x.between_tokens))) > 0 else ABSTAIN 104 | 105 | 106 | # %% 107 | # Check for the `spouse` words appearing to the left of the person mentions 108 | @labeling_function(resources=dict(spouses=spouses), pre=[get_left_tokens]) 109 | def lf_husband_wife_left_window(x, spouses): 110 | if len(set(spouses).intersection(set(x.person1_left_tokens))) > 0: 111 | return POSITIVE 112 | elif len(set(spouses).intersection(set(x.person2_left_tokens))) > 0: 113 | return POSITIVE 114 | else: 115 | return ABSTAIN 116 | 117 | 118 | # %% 119 | # Check for the person mentions having the same last name 120 | @labeling_function(pre=[get_person_last_names]) 121 | def lf_same_last_name(x): 122 | p1_ln, p2_ln = x.person_lastnames 123 | 124 | if p1_ln and p2_ln and p1_ln == p2_ln: 125 | return POSITIVE 126 | return ABSTAIN 127 | 128 | 129 | # %% 130 | # Check for the word `married` between person mentions 131 | @labeling_function() 132 | def lf_married(x): 133 | return POSITIVE if "married" in x.between_tokens else ABSTAIN 134 | 135 | 136 | # %% 137 | # Check for words that refer to `family` relationships between and to the left of the person mentions 138 | family = { 139 | "father", 140 | "mother", 141 | "sister", 142 | "brother", 143 | "son", 144 | "daughter", 145 | "grandfather", 146 | "grandmother", 147 | "uncle", 148 | "aunt", 149 | "cousin", 150 | } 151 | family = family.union({f + "-in-law" for f in family}) 152 | 153 | 154 | @labeling_function(resources=dict(family=family)) 155 | def lf_familial_relationship(x, family): 156 | return NEGATIVE if len(family.intersection(set(x.between_tokens))) > 0 else ABSTAIN 157 | 158 | 159 | @labeling_function(resources=dict(family=family), pre=[get_left_tokens]) 160 | def lf_family_left_window(x, family): 161 | if len(set(family).intersection(set(x.person1_left_tokens))) > 0: 162 | return NEGATIVE 163 | elif len(set(family).intersection(set(x.person2_left_tokens))) > 0: 164 | return NEGATIVE 165 | else: 166 | return ABSTAIN 167 | 168 | 169 | # %% 170 | # Check for `other` relationship words between person mentions 171 | other = {"boyfriend", "girlfriend", "boss", "employee", "secretary", "co-worker"} 172 | 173 | 174 | @labeling_function(resources=dict(other=other)) 175 | def lf_other_relationship(x, other): 176 | return NEGATIVE if len(other.intersection(set(x.between_tokens))) > 0 else ABSTAIN 177 | 178 | 179 | # %% [markdown] 180 | # ### Distant Supervision Labeling Functions 181 | # 182 | # In addition to using factories that encode pattern matching heuristics, we can also write labeling functions that _distantly supervise_ data points. 
Here, we'll load in a list of known spouse pairs and check to see if the pair of persons in a candidate matches one of these. 183 | # 184 | # [**DBpedia**](http://wiki.dbpedia.org/): Our database of known spouses comes from DBpedia, which is a community-driven resource similar to Wikipedia but for curating structured data. We'll use a preprocessed snapshot as our knowledge base for all labeling function development. 185 | # 186 | # We can look at some of the example entries from DBPedia and use them in a simple distant supervision labeling function. 187 | # 188 | # Make sure `dbpedia.pkl` is in the `spouse/data` directory. 189 | 190 | # %% 191 | with open("data/dbpedia.pkl", "rb") as f: 192 | known_spouses = pickle.load(f) 193 | 194 | list(known_spouses)[0:5] 195 | 196 | 197 | # %% 198 | @labeling_function(resources=dict(known_spouses=known_spouses), pre=[get_person_text]) 199 | def lf_distant_supervision(x, known_spouses): 200 | p1, p2 = x.person_names 201 | if (p1, p2) in known_spouses or (p2, p1) in known_spouses: 202 | return POSITIVE 203 | else: 204 | return ABSTAIN 205 | 206 | 207 | # %% 208 | from preprocessors import last_name 209 | 210 | # Last name pairs for known spouses 211 | last_names = set( 212 | [ 213 | (last_name(x), last_name(y)) 214 | for x, y in known_spouses 215 | if last_name(x) and last_name(y) 216 | ] 217 | ) 218 | 219 | 220 | @labeling_function(resources=dict(last_names=last_names), pre=[get_person_last_names]) 221 | def lf_distant_supervision_last_names(x, last_names): 222 | p1_ln, p2_ln = x.person_lastnames 223 | 224 | return ( 225 | POSITIVE 226 | if (p1_ln != p2_ln) 227 | and ((p1_ln, p2_ln) in last_names or (p2_ln, p1_ln) in last_names) 228 | else ABSTAIN 229 | ) 230 | 231 | 232 | # %% [markdown] 233 | # #### Apply Labeling Functions to the Data 234 | # We create a list of labeling functions and apply them to the data 235 | 236 | # %% 237 | from snorkel.labeling import PandasLFApplier 238 | 239 | lfs = [ 240 | lf_husband_wife, 241 | lf_husband_wife_left_window, 242 | lf_same_last_name, 243 | lf_married, 244 | lf_familial_relationship, 245 | lf_family_left_window, 246 | lf_other_relationship, 247 | lf_distant_supervision, 248 | lf_distant_supervision_last_names, 249 | ] 250 | applier = PandasLFApplier(lfs) 251 | 252 | # %% {"tags": ["md-exclude-output"]} 253 | from snorkel.labeling import LFAnalysis 254 | 255 | L_dev = applier.apply(df_dev) 256 | L_train = applier.apply(df_train) 257 | 258 | # %% 259 | LFAnalysis(L_dev, lfs).lf_summary(Y_dev) 260 | 261 | # %% [markdown] 262 | # ### Training the Label Model 263 | # 264 | # Now, we'll train a model of the LFs to estimate their weights and combine their outputs. Once the model is trained, we can combine the outputs of the LFs into a single, noise-aware training label set for our extractor. 265 | 266 | # %% {"tags": ["md-exclude-output"]} 267 | from snorkel.labeling.model import LabelModel 268 | 269 | label_model = LabelModel(cardinality=2, verbose=True) 270 | label_model.fit(L_train, Y_dev, n_epochs=5000, log_freq=500, seed=12345) 271 | 272 | # %% [markdown] 273 | # ### Label Model Metrics 274 | # Since our dataset is highly unbalanced (91% of the labels are negative), even a trivial baseline that always outputs negative can get a high accuracy. So we evaluate the label model using the F1 score and ROC-AUC rather than accuracy. 
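# %% [markdown]
# To make that concrete, here is a minimal sketch using made-up numbers that mirror the 91% figure above (not our
# actual dev labels): an "always negative" baseline looks accurate but is useless, which F1 exposes immediately.

# %%
from sklearn.metrics import accuracy_score, f1_score

toy_labels = [NEGATIVE] * 91 + [POSITIVE] * 9  # 91% negative, as in our dataset
toy_preds = [NEGATIVE] * 100  # trivial baseline: always predict negative
print(f"Baseline accuracy: {accuracy_score(toy_labels, toy_preds):.2f}")  # 0.91
print(f"Baseline F1: {f1_score(toy_labels, toy_preds):.2f}")  # 0.00 -- it never finds a positive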
275 | 276 | # %% 277 | from snorkel.analysis import metric_score 278 | from snorkel.utils import probs_to_preds 279 | 280 | probs_dev = label_model.predict_proba(L_dev) 281 | preds_dev = probs_to_preds(probs_dev) 282 | print( 283 | f"Label model f1 score: {metric_score(Y_dev, preds_dev, probs=probs_dev, metric='f1')}" 284 | ) 285 | print( 286 | f"Label model roc-auc: {metric_score(Y_dev, preds_dev, probs=probs_dev, metric='roc_auc')}" 287 | ) 288 | 289 | # %% [markdown] 290 | # ### Part 4: Training our End Extraction Model 291 | # 292 | # In this final section of the tutorial, we'll use our noisy training labels to train our end machine learning model. We start by filtering out training data points which did not receive a label from any LF, as these data points contain no signal. 293 | # 294 | # %% 295 | from snorkel.labeling import filter_unlabeled_dataframe 296 | 297 | probs_train = label_model.predict_proba(L_train) 298 | df_train_filtered, probs_train_filtered = filter_unlabeled_dataframe( 299 | X=df_train, y=probs_train, L=L_train 300 | ) 301 | 302 | # %% [markdown] 303 | # Next, we train a simple [LSTM](https://en.wikipedia.org/wiki/Long_short-term_memory) network for classifying candidates. `tf_model` contains functions for processing features and building the Keras model for training and evaluation. 304 | 305 | # %% {"tags": ["md-exclude-output"]} 306 | from tf_model import get_model, get_feature_arrays 307 | from utils import get_n_epochs 308 | 309 | X_train = get_feature_arrays(df_train_filtered) 310 | model = get_model() 311 | batch_size = 64 312 | model.fit(X_train, probs_train_filtered, batch_size=batch_size, epochs=get_n_epochs()) 313 | 314 | # %% [markdown] 315 | # Finally, we evaluate the trained model by measuring its F1 score and ROC-AUC. 316 | 317 | # %% 318 | X_test = get_feature_arrays(df_test) 319 | probs_test = model.predict(X_test) 320 | preds_test = probs_to_preds(probs_test) 321 | print( 322 | f"Test F1 when trained with soft labels: {metric_score(Y_test, preds=preds_test, metric='f1')}" 323 | ) 324 | print( 325 | f"Test ROC-AUC when trained with soft labels: {metric_score(Y_test, probs=probs_test, metric='roc_auc')}" 326 | ) 327 | 328 | # %% [markdown] 329 | # ## Summary 330 | # In this tutorial, we showed how Snorkel can be used for Information Extraction. We demonstrated how to create LFs that leverage keywords and external knowledge bases (distant supervision). Finally, we showed how a model trained using the probabilistic outputs of the Label Model can achieve comparable performance while generalizing to all data points. 331 | -------------------------------------------------------------------------------- /getting_started/getting_started.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # %% [markdown] {"tags": ["md-exclude"]} 3 | # # Getting Started with Snorkel 4 | 5 | # %% {"tags": ["md-exclude"]} 6 | import os 7 | 8 | # Make sure we're running from the getting_started/ directory 9 | if os.path.basename(os.getcwd()) == "snorkel-tutorials": 10 | os.chdir("getting_started") 11 | 12 | # %% [markdown] 13 | # ## Programmatically Building and Managing Training Data with Snorkel 14 | # 15 | # Snorkel is a system for _programmatically_ building and managing training datasets **without manual labeling**. 16 | # In Snorkel, users can develop large training datasets in hours or days rather than hand-labeling them over weeks or months.
17 | # 18 | # Snorkel currently exposes three key programmatic operations: 19 | # - **Labeling data**, e.g., using heuristic rules or distant supervision techniques 20 | # - **Transforming data**, e.g., rotating or stretching images to perform data augmentation 21 | # - **Slicing data** into different critical subsets for monitoring or targeted improvement 22 | # 23 | # Snorkel then automatically models, cleans, and integrates the resulting training data using novel, theoretically-grounded techniques. 24 | 25 | # %% [markdown] 26 | # 27 | 28 | # %% [markdown] 29 | # In this quick walkthrough, we'll preview the high-level workflow and interfaces of Snorkel using a canonical machine learning problem: classifying spam. 30 | # We'll use a public [YouTube comments dataset](http://www.dt.fee.unicamp.br/~tiago//youtubespamcollection/), and see how **Snorkel can enable training a machine learning model without _any_ hand-labeled training data!** 31 | # For more detailed versions of the sections in this walkthrough, see the corresponding tutorials: ([Spam LFs](https://snorkel.org/use-cases/01-spam-tutorial), [Spam TFs](https://snorkel.org/use-cases/02-spam-data-augmentation-tutorial), [Spam SFs](https://snorkel.org/use-cases/03-spam-data-slicing-tutorial)). 32 | 33 | # %% [markdown] 34 | # We'll walk through five basic steps: 35 | # 36 | # 1. **Writing Labeling Functions (LFs):** First, rather than hand-labeling any training data, we'll programmatically label our _unlabeled_ dataset with LFs. 37 | # 2. **Modeling & Combining LFs:** Next, we'll use Snorkel's `LabelModel` to automatically learn the accuracies of our LFs and reweight and combine their outputs into a single, confidence-weighted training label per data point. 38 | # 3. **Writing Transformation Functions (TFs) for Data Augmentation:** Then, we'll augment this labeled training set by writing a simple TF. 39 | # 4. **Writing _Slicing Functions (SFs)_ for Data Subset Selection:** We'll also preview writing an SF to identify a critical subset or _slice_ of our training set. 40 | # 5. **Training a final ML model:** Finally, we'll train an ML model with our training set. 41 | # 42 | # We'll start first by loading the _unlabeled_ comments, which we'll use as our training data, as a Pandas `DataFrame`: 43 | 44 | # %% 45 | from utils import load_unlabeled_spam_dataset 46 | 47 | df_train = load_unlabeled_spam_dataset() 48 | 49 | # %% [markdown] 50 | # ## 1) Writing Labeling Functions 51 | # 52 | # _Labeling functions (LFs)_ are one of the core operators for building and managing training datasets programmatically in Snorkel. 53 | # The basic idea is simple: **a labeling function is a function that outputs a label for some subset of the training dataset**. 54 | # In our example here, each labeling function takes as input a comment data point, and either outputs a label (`SPAM = 1` or `NOT_SPAM = 0`) or abstains from labeling (`ABSTAIN = -1`): 55 | 56 | # %% 57 | # Define the label mappings for convenience 58 | ABSTAIN = -1 59 | NOT_SPAM = 0 60 | SPAM = 1 61 | 62 | # %% [markdown] 63 | # Labeling functions can be used to represent many heuristic and/or noisy strategies for labeling data, often referred to as [weak supervision](https://www.snorkel.org/blog/weak-supervision). 64 | # The basic idea of labeling functions, and other programmatic operators in Snorkel, is to let users inject domain information into machine learning models in higher level, higher bandwidth ways than manually labeling thousands or millions of individual data points. 
65 | # **The key idea is that labeling functions do not need to be perfectly accurate**, and can in fact even be correlated with each other. 66 | # Snorkel will automatically estimate their accuracies and correlations in a [provably consistent way](https://papers.nips.cc/paper/6523-data-programming-creating-large-training-sets-quickly), and then reweight and combine their output labels, leading to high-quality training labels. 67 | 68 | # %% [markdown] 69 | # In our text data setting here, labeling functions use: 70 | # 71 | # Keyword matches: 72 | 73 | # %% 74 | from snorkel.labeling import labeling_function 75 | 76 | 77 | @labeling_function() 78 | def lf_keyword_my(x): 79 | """Many spam comments talk about 'my channel', 'my video', etc.""" 80 | return SPAM if "my" in x.text.lower() else ABSTAIN 81 | 82 | 83 | # %% [markdown] 84 | # Regular expressions: 85 | 86 | # %% 87 | import re 88 | 89 | 90 | @labeling_function() 91 | def lf_regex_check_out(x): 92 | """Spam comments say 'check out my video', 'check it out', etc.""" 93 | return SPAM if re.search(r"check.*out", x.text, flags=re.I) else ABSTAIN 94 | 95 | 96 | # %% [markdown] 97 | # Arbitrary heuristics: 98 | 99 | # %% 100 | @labeling_function() 101 | def lf_short_comment(x): 102 | """Non-spam comments are often short, such as 'cool video!'.""" 103 | return NOT_SPAM if len(x.text.split()) < 5 else ABSTAIN 104 | 105 | 106 | # %% [markdown] 107 | # Third-party models: 108 | 109 | # %% 110 | from textblob import TextBlob 111 | 112 | 113 | @labeling_function() 114 | def lf_textblob_polarity(x): 115 | """ 116 | We use a third-party sentiment classification model, TextBlob. 117 | 118 | We combine this with the heuristic that non-spam comments are often positive. 119 | """ 120 | return NOT_SPAM if TextBlob(x.text).sentiment.polarity > 0.3 else ABSTAIN 121 | 122 | 123 | # %% [markdown] 124 | # And much more! 125 | # For many more types of labeling functions — including over data modalities beyond text — see the other [tutorials](https://snorkel.org/use-cases/) and [real-world examples](https://snorkel.org/resources/). 126 | # 127 | # In general, the process of developing labeling functions is, like any other development process, an iterative one that takes time. 128 | # However, in many cases it can be _orders-of-magnitude_ faster than hand-labeling training data. 129 | # For more detail on the process of developing labeling functions and other training data operators in Snorkel, see the [Introduction Tutorials](https://snorkel.org/use-cases). 130 | 131 | # %% [markdown] 132 | # ## 2) Combining & Cleaning the Labels 133 | # 134 | # Our next step is to apply the labeling functions we wrote to the unlabeled training data. 135 | # The result is a *label matrix*, `L_train`, where each row corresponds to a data point and each column corresponds to a labeling function. 136 | # Since the labeling functions have unknown accuracies and correlations, their output labels may overlap and conflict.
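A quick way to see these properties concretely is Snorkel's `LFAnalysis` report, which the recsys tutorial later in this dump also prints. The following is a small diagnostic sketch, not part of the original walkthrough: it assumes the four LFs defined above, and `lfs_preview` / `L_preview` are illustrative names only (the next cell applies the LFs again when fitting the label model).

```python
# Diagnostic sketch only: summarize each LF's coverage, overlaps, and conflicts
# on the unlabeled training set before fitting the label model.
from snorkel.labeling import LFAnalysis, PandasLFApplier

lfs_preview = [lf_keyword_my, lf_regex_check_out, lf_short_comment, lf_textblob_polarity]
L_preview = PandasLFApplier(lfs_preview).apply(df_train)

# One row per LF: polarity, coverage, overlaps, and conflicts.
LFAnalysis(L_preview, lfs_preview).lf_summary()
```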
137 | # We use the `LabelModel` to automatically estimate their accuracies and correlations, reweight and combine their labels, and produce our final set of clean, integrated training labels: 138 | 139 | # %% 140 | from snorkel.labeling.model import LabelModel 141 | from snorkel.labeling import PandasLFApplier 142 | 143 | # Define the set of labeling functions (LFs) 144 | lfs = [lf_keyword_my, lf_regex_check_out, lf_short_comment, lf_textblob_polarity] 145 | 146 | # Apply the LFs to the unlabeled training data 147 | applier = PandasLFApplier(lfs) 148 | L_train = applier.apply(df_train) 149 | 150 | # Train the label model and compute the training labels 151 | label_model = LabelModel(cardinality=2, verbose=True) 152 | label_model.fit(L_train, n_epochs=500, log_freq=50, seed=123) 153 | df_train["label"] = label_model.predict(L=L_train, tie_break_policy="abstain") 154 | 155 | # %% [markdown] 156 | # Note that we used the `LabelModel` to label data; however, on many data points, all the labeling functions abstain, and so the `LabelModel` abstains as well. 157 | # We'll filter these data points out of our training set now: 158 | 159 | # %% 160 | df_train = df_train[df_train.label != ABSTAIN] 161 | 162 | # %% [markdown] 163 | # Our ultimate goal is to use the resulting labeled training data points to train a machine learning model that can **generalize beyond the coverage of the labeling functions and the `LabelModel`**. 164 | # However first we'll explore some of Snorkel's other operators for building and managing training data. 165 | 166 | # %% [markdown] 167 | # ## 3) Writing Transformation Functions for Data Augmentation 168 | # 169 | # An increasingly popular and critical technique in modern machine learning is [data augmentation](https://www.snorkel.org/blog/tanda), 170 | # the strategy of artificially *augmenting* existing labeled training datasets by creating transformed copies of the data points. 171 | # Data augmentation is a practical and powerful method for injecting information about domain invariances into ML models via the data, rather than by trying to modify their internal architectures. 172 | # The canonical example is randomly rotating, stretching, and transforming images when training image classifiers — a ubiquitous technique in the field of computer vision today. 173 | # However, data augmentation is increasingly used in a range of settings, including text. 174 | # 175 | # Here, we implement a simple text data augmentation strategy — randomly replacing a word with a synonym. 
176 | # We express this as a *transformation function (TF)*: 177 | 178 | # %% 179 | import random 180 | 181 | import nltk 182 | from nltk.corpus import wordnet as wn 183 | 184 | from snorkel.augmentation import transformation_function 185 | 186 | nltk.download("wordnet", quiet=True) 187 | 188 | 189 | def get_synonyms(word): 190 | """Get the synonyms of word from Wordnet.""" 191 | lemmas = set().union(*[s.lemmas() for s in wn.synsets(word)]) 192 | return list(set(l.name().lower().replace("_", " ") for l in lemmas) - {word}) 193 | 194 | 195 | @transformation_function() 196 | def tf_replace_word_with_synonym(x): 197 | """Try to replace a random word with a synonym.""" 198 | words = x.text.lower().split() 199 | idx = random.choice(range(len(words))) 200 | synonyms = get_synonyms(words[idx]) 201 | if len(synonyms) > 0: 202 | x.text = " ".join(words[:idx] + [synonyms[0]] + words[idx + 1 :]) 203 | return x 204 | 205 | 206 | # %% [markdown] 207 | # Next, we apply this transformation function to our training dataset: 208 | 209 | # %% 210 | from snorkel.augmentation import ApplyOnePolicy, PandasTFApplier 211 | 212 | tf_policy = ApplyOnePolicy(n_per_original=2, keep_original=True) 213 | tf_applier = PandasTFApplier([tf_replace_word_with_synonym], tf_policy) 214 | df_train_augmented = tf_applier.apply(df_train) 215 | 216 | # %% [markdown] 217 | # Note that a common challenge with data augmentation is figuring out how to tune and apply different transformation functions to best augment a training set. 218 | # This is most commonly done as an ad hoc manual process; however, in Snorkel, various approaches for using automatically learned data augmentation _policies_ are supported. 219 | # For more detail, see the [Spam TFs tutorial](https://snorkel.org/use-cases/02-spam-data-augmentation-tutorial). 220 | 221 | # %% [markdown] 222 | # ## 4) Writing a Slicing Function 223 | # 224 | # Finally, a third operator in Snorkel, *slicing functions (SFs)*, handles the reality that many datasets have certain subsets or _slices_ that are more important than others. 225 | # In Snorkel, we can write SFs to (a) monitor specific slices and (b) improve model performance over them by adding representational capacity targeted on a per-slice basis. 226 | # 227 | # Writing a slicing function is simple. 228 | # For example, we could write one that looks for suspiciously shortened links, which might be critical due to their likelihood of linking to malicious sites: 229 | 230 | # %% 231 | from snorkel.slicing import slicing_function 232 | 233 | 234 | @slicing_function() 235 | def short_link(x): 236 | """Return whether text matches common pattern for shortened ".ly" links.""" 237 | return int(bool(re.search(r"\w+\.ly", x.text))) 238 | 239 | 240 | # %% [markdown] 241 | # We can now use Snorkel to monitor the performance over this slice, and to add representational capacity to our model in order to potentially increase performance on this slice. 242 | # For a walkthrough of these steps, see the [Spam SFs tutorial](https://snorkel.org/use-cases/03-spam-data-slicing-tutorial). 243 | 244 | # %% [markdown] 245 | # ## 5) Training a Classifier 246 | # 247 | # The ultimate goal in Snorkel is to **create a training dataset**, which can then be plugged into an arbitrary machine learning framework (e.g. TensorFlow, Keras, PyTorch, Scikit-Learn, Ludwig, XGBoost) to train powerful machine learning models. 
248 | # Here, to complete this initial walkthrough, we'll train an extremely simple model — a "bag of n-grams" logistic regression model in Scikit-Learn — using the weakly labeled and augmented training set we made with our labeling and transformation functions: 249 | 250 | # %% 251 | from sklearn.feature_extraction.text import CountVectorizer 252 | from sklearn.linear_model import LogisticRegression 253 | 254 | train_text = df_train_augmented.text.tolist() 255 | X_train = CountVectorizer(ngram_range=(1, 2)).fit_transform(train_text) 256 | 257 | clf = LogisticRegression(solver="lbfgs") 258 | clf.fit(X=X_train, y=df_train_augmented.label.values) 259 | 260 | # %% [markdown] 261 | # And that's it — you've trained your first model **without hand-labeling _any_ training data!** 262 | # Next, to learn more about Snorkel, check out the [tutorials](https://snorkel.org/use-cases/), [resources](https://snorkel.org/resources), and [documentation](https://snorkel.readthedocs.io) for much more on how to use Snorkel to power your own machine learning applications. 263 | -------------------------------------------------------------------------------- /recsys/recsys_tutorial.py: -------------------------------------------------------------------------------- 1 | # %% [markdown] 2 | # # Recommender Systems Tutorial 3 | # In this tutorial, we'll provide a simple walkthrough of how to use Snorkel to build a recommender system. 4 | # We consider a setting similar to the [Netflix challenge](https://www.kaggle.com/netflix-inc/netflix-prize-data), but with books instead of movies. 5 | # We have a set of users and books, and for each user we know the set of books they have interacted with (read or marked as to-read). 6 | # We don't have the user's numerical ratings for the books they read, except in a small number of cases. 7 | # We also have some text reviews written by users. 8 | # 9 | # Our goal is to build a recommender system by training a classifier to predict whether a user will read and like any given book. 10 | # We'll train our model over a user-book pair to predict a `rating` (a `rating` of 1 means the user will read and like the book). 11 | # To simplify inference, we'll represent a user by the set of books they interacted with (rather than learning a specific representation for each user). 12 | # Once we have this model trained, we can use it to recommend books to a user when they visit the site. 13 | # For example, we can just predict the rating for the user paired with a book for a few thousand likely books, then pick the books with the top ten predicted ratings. 14 | # 15 | # Of course, there are many other ways to approach this problem. 16 | # The field of [recommender systems](https://en.wikipedia.org/wiki/Recommender_system) is a very well studied area with a wide variety of settings and approaches, and we just focus on one of them. 17 | # 18 | # We will use the [Goodreads](https://sites.google.com/eng.ucsd.edu/ucsdbookgraph/home) dataset, from 19 | # "Item Recommendation on Monotonic Behavior Chains", RecSys'18 (Mengting Wan, Julian McAuley), and "Fine-Grained Spoiler Detection from Large-Scale Review Corpora", ACL'19 (Mengting Wan, Rishabh Misra, Ndapa Nakashole, Julian McAuley). 20 | # In this dataset, we have user interactions and reviews for Young Adult novels from the Goodreads website, along with metadata (like `title` and `authors`) for the novels. 
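The top-ten selection sketched above is not implemented in this walkthrough. A rough illustration of what it could look like at serving time is below; it assumes the Keras `model` trained at the end of this tutorial (whose inputs are the history length, the history `book_idxs`, and a candidate `book_idx`), and `recommend_top_k` is a hypothetical helper rather than tutorial code.

```python
# Hypothetical serving-time sketch: score candidate books for one user with the
# trained model and keep the ten highest predicted ratings.
import numpy as np


def recommend_top_k(model, user_book_idxs, candidate_book_idxs, k=10):
    n = len(candidate_book_idxs)
    # The model takes [len(book_idxs), book_idxs, book_idx] as its three inputs.
    len_in = np.full(n, len(user_book_idxs), dtype=np.int64)
    history_in = np.tile(np.array(user_book_idxs, dtype=np.int64), (n, 1))
    candidate_in = np.array(candidate_book_idxs, dtype=np.int64)
    scores = model.predict([len_in, history_in, candidate_in]).ravel()
    top = np.argsort(-scores)[:k]  # indices of the k highest predicted ratings
    return [(candidate_book_idxs[i], float(scores[i])) for i in top]
```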
21 | 22 | # %% {"tags": ["md-exclude"]} 23 | import logging 24 | import os 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | 29 | if os.path.basename(os.getcwd()) == "snorkel-tutorials": 30 | os.chdir("recsys") 31 | 32 | 33 | # %% [markdown] 34 | # ## Loading Data 35 | 36 | # %% [markdown] 37 | # We start by running the `download_and_process_data` function. 38 | # The function returns the `df_train`, `df_test`, `df_dev`, `df_valid` dataframes, which correspond to our training, test, development, and validation sets. 39 | # Each of those dataframes has the following fields: 40 | # * `user_idx`: A unique identifier for a user. 41 | # * `book_idx`: A unique identifier for a book that is being rated by the user. 42 | # * `book_idxs`: The set of books that the user has interacted with (read or planned to read). 43 | # * `review_text`: Optional text review written by the user for the book. 44 | # * `rating`: Either `0` (which means the user did not read or did not like the book) or `1` (which means the user read and liked the book). The `rating` field is missing for `df_train`. 45 | # Our objective is to predict whether a given user (represented by the set of book_idxs the user has interacted with) will read and like any given book. 46 | # That is, we want to train a model that takes a set of `book_idxs` (the user) and a single `book_idx` (the book to rate) and predicts the `rating`. 47 | # 48 | # In addition, `download_and_process_data` also returns the `df_books` dataframe, which contains one row per book, along with metadata for that book (such as `title` and `first_author`). 49 | 50 | # %% {"tags": ["md-exclude-output"]} 51 | from utils import download_and_process_data 52 | 53 | (df_train, df_test, df_dev, df_valid), df_books = download_and_process_data() 54 | 55 | df_books.head() 56 | 57 | # %% [markdown] 58 | # We look at a sample of the labeled development set. 59 | # As an example, we want our final recommendations model to be able to predict that a user who has interacted with `book_idxs` (25743, 22318, 7662, 6857, 83, 14495, 30664, ...) would either not read or not like the book with `book_idx` 22764 (first row), while a user who has interacted with `book_idxs` (3880, 18078, 9092, 29933, 1511, 8560, ...) would read and like the book with `book_idx` 3181 (second row). 60 | 61 | # %% 62 | df_dev.sample(frac=1, random_state=12).head() 63 | 64 | # %% [markdown] 65 | # ## Writing Labeling Functions 66 | 67 | # %% 68 | POSITIVE = 1 69 | NEGATIVE = 0 70 | ABSTAIN = -1 71 | 72 | # %% [markdown] 73 | # If a user has interacted with several books written by an author, there is a good chance that the user will read and like other books by the same author. 74 | # We express this as a labeling function, using the `first_author` field in the `df_books` dataframe. 75 | # We picked the threshold 15 by plotting histograms and running error analysis using the dev set. 
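The histogram and error analysis mentioned above are described but not shown. Below is a rough sketch of that threshold analysis on the dev set; it assumes `df_dev` and `df_books` from the cells above, that matplotlib is available, and `same_author_count` is an illustrative helper rather than tutorial code.

```python
# Hypothetical sketch of the threshold analysis: for each dev example, count how
# many books in the user's history share the rated book's first author, then
# check how precise a "shared author" rule would be at a few thresholds.
import matplotlib.pyplot as plt

book_to_author = dict(zip(df_books.book_idx, df_books.first_author))
author_to_books = df_books.groupby("first_author")["book_idx"].apply(set).to_dict()


def same_author_count(row):
    return len(set(row.book_idxs) & author_to_books[book_to_author[row.book_idx]])


counts = df_dev.apply(same_author_count, axis=1)

plt.hist([counts[df_dev.rating == 1], counts[df_dev.rating == 0]],
         bins=20, label=["rating = 1", "rating = 0"])
plt.xlabel("# books by the same first author in the user's history")
plt.legend()
plt.show()

# Empirical coverage and precision of "POSITIVE if count > t" on the dev set.
for t in [5, 10, 15, 20]:
    covered = counts > t
    if covered.any():
        print(t, int(covered.sum()), round(float((df_dev.rating[covered] == 1).mean()), 3))
```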
76 | 77 | # %% 78 | from snorkel.labeling.lf import labeling_function 79 | 80 | book_to_first_author = dict(zip(df_books.book_idx, df_books.first_author)) 81 | first_author_to_books_df = df_books.groupby("first_author")[["book_idx"]].agg(set) 82 | first_author_to_books = dict( 83 | zip(first_author_to_books_df.index, first_author_to_books_df.book_idx) 84 | ) 85 | 86 | 87 | @labeling_function( 88 | resources=dict( 89 | book_to_first_author=book_to_first_author, 90 | first_author_to_books=first_author_to_books, 91 | ) 92 | ) 93 | def shared_first_author(x, book_to_first_author, first_author_to_books): 94 | author = book_to_first_author[x.book_idx] 95 | same_author_books = first_author_to_books[author] 96 | num_read = len(set(x.book_idxs).intersection(same_author_books)) 97 | return POSITIVE if num_read > 15 else ABSTAIN 98 | 99 | 100 | # %% [markdown] 101 | # We can also leverage the long text reviews written by users to guess whether they liked or disliked a book. 102 | # For example, the third `df_dev` entry above has a review with the text `'4.5 STARS'`, which indicates that the user liked the book. 103 | # We write a simple LF that looks for similar phrases to guess the user's rating of a book. 104 | # We interpret >= 4 stars to indicate a positive rating, while < 4 stars is negative. 105 | 106 | # %% 107 | low_rating_strs = [ 108 | "one star", 109 | "1 star", 110 | "two star", 111 | "2 star", 112 | "3 star", 113 | "three star", 114 | "3.5 star", 115 | "2.5 star", 116 | "1 out of 5", 117 | "2 out of 5", 118 | "3 out of 5", 119 | ] 120 | high_rating_strs = ["5 stars", "five stars", "four stars", "4 stars", "4.5 stars"] 121 | 122 | 123 | @labeling_function( 124 | resources=dict(low_rating_strs=low_rating_strs, high_rating_strs=high_rating_strs) 125 | ) 126 | def stars_in_review(x, low_rating_strs, high_rating_strs): 127 | if not isinstance(x.review_text, str): 128 | return ABSTAIN 129 | for low_rating_str in low_rating_strs: 130 | if low_rating_str in x.review_text.lower(): 131 | return NEGATIVE 132 | for high_rating_str in high_rating_strs: 133 | if high_rating_str in x.review_text.lower(): 134 | return POSITIVE 135 | return ABSTAIN 136 | 137 | 138 | # %% [markdown] 139 | # We can also run [TextBlob](https://textblob.readthedocs.io/en/dev/index.html), a tool that provides a pretrained sentiment analyzer, on the reviews, and use its polarity and subjectivity scores to estimate the user's rating for the book. 140 | # As usual, these thresholds were picked by analyzing the score distributions and running error analysis. 141 | 142 | # %% 143 | from snorkel.preprocess import preprocessor 144 | from textblob import TextBlob 145 | 146 | 147 | @preprocessor(memoize=True) 148 | def textblob_polarity(x): 149 | if isinstance(x.review_text, str): 150 | x.blob = TextBlob(x.review_text) 151 | else: 152 | x.blob = None 153 | return x 154 | 155 | 156 | # Label high polarity reviews as positive. 157 | @labeling_function(pre=[textblob_polarity]) 158 | def polarity_positive(x): 159 | if x.blob: 160 | if x.blob.polarity > 0.3: 161 | return POSITIVE 162 | return ABSTAIN 163 | 164 | 165 | # Label high subjectivity reviews as positive. 166 | @labeling_function(pre=[textblob_polarity]) 167 | def subjectivity_positive(x): 168 | if x.blob: 169 | if x.blob.subjectivity > 0.75: 170 | return POSITIVE 171 | return ABSTAIN 172 | 173 | 174 | # Label low polarity reviews as negative. 
175 | @labeling_function(pre=[textblob_polarity]) 176 | def polarity_negative(x): 177 | if x.blob: 178 | if x.blob.polarity < 0.0: 179 | return NEGATIVE 180 | return ABSTAIN 181 | 182 | 183 | # %% {"tags": ["md-exclude-output"]} 184 | from snorkel.labeling import PandasLFApplier, LFAnalysis 185 | 186 | lfs = [ 187 | stars_in_review, 188 | shared_first_author, 189 | polarity_positive, 190 | subjectivity_positive, 191 | polarity_negative, 192 | ] 193 | 194 | applier = PandasLFApplier(lfs) 195 | L_dev = applier.apply(df_dev) 196 | 197 | # %% 198 | LFAnalysis(L_dev, lfs).lf_summary(df_dev.rating.values) 199 | 200 | # %% [markdown] 201 | # ### Applying labeling functions to the training set 202 | # 203 | # We apply the labeling functions to the training set, and then filter out data points unlabeled by any LF to form our final training set. 204 | 205 | # %% {"tags": ["md-exclude-output"]} 206 | from snorkel.labeling.model import LabelModel 207 | 208 | L_train = applier.apply(df_train) 209 | label_model = LabelModel(cardinality=2, verbose=True) 210 | label_model.fit(L_train, n_epochs=5000, seed=123, log_freq=20, lr=0.01) 211 | preds_train = label_model.predict(L_train) 212 | 213 | # %% {"tags": ["md-exclude-output"]} 214 | from snorkel.labeling import filter_unlabeled_dataframe 215 | 216 | df_train_filtered, preds_train_filtered = filter_unlabeled_dataframe( 217 | df_train, preds_train, L_train 218 | ) 219 | df_train_filtered["rating"] = preds_train_filtered 220 | 221 | # %% [markdown] 222 | # ### Rating Prediction Model 223 | # We write a Keras model for predicting ratings given a user's book list and a book (which is being rated). 224 | # The model represents the list of books the user interacted with, `book_idxs`, by learning an embedding for each idx, and averaging the embeddings in `book_idxs`. 225 | # It learns another embedding for the `book_idx`, the book to be rated. 226 | # Then it concatenates the two embeddings and uses an [MLP](https://en.wikipedia.org/wiki/Multilayer_perceptron) to compute the probability of the `rating` being 1. 227 | # This type of model is common in large-scale recommender systems, for example, the [YouTube recommender system](https://ai.google/research/pubs/pub45530). 228 | 229 | # %% 230 | import numpy as np 231 | import tensorflow as tf 232 | from utils import precision_batch, recall_batch, f1_batch 233 | 234 | n_books = max([max(df.book_idx) for df in [df_train, df_test, df_dev, df_valid]]) 235 | 236 | 237 | # Keras model to predict rating given book_idxs and book_idx. 238 | def get_model(embed_dim=64, hidden_layer_sizes=[32]): 239 | # Compute embedding for book_idxs. 240 | len_book_idxs = tf.keras.layers.Input([]) 241 | book_idxs = tf.keras.layers.Input([None]) 242 | # book_idxs % n_books is to prevent crashing if a book_idx in book_idxs is > n_books. 243 | book_idxs_emb = tf.keras.layers.Embedding(n_books, embed_dim)(book_idxs % n_books) 244 | book_idxs_emb = tf.math.divide( 245 | tf.keras.backend.sum(book_idxs_emb, axis=1), tf.expand_dims(len_book_idxs, 1) 246 | ) 247 | # Compute embedding for book_idx. 248 | book_idx = tf.keras.layers.Input([]) 249 | book_idx_emb = tf.keras.layers.Embedding(n_books, embed_dim)(book_idx) 250 | input_layer = tf.keras.layers.concatenate([book_idxs_emb, book_idx_emb], 1) 251 | # Build Multi Layer Perceptron on input layer.
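    # Each Dense layer below is applied with the Keras functional API: the
    # previous layer's output tensor (cur_layer) is passed in and the result is
    # assigned back to cur_layer, so the hidden layers are chained before the
    # final sigmoid output.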
252 | cur_layer = input_layer 253 | for size in hidden_layer_sizes: 254 | cur_layer = tf.keras.layers.Dense(size, activation=tf.nn.relu)(cur_layer) 255 | output_layer = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(cur_layer) 256 | # Create and compile Keras model. 257 | model = tf.keras.Model( 258 | inputs=[len_book_idxs, book_idxs, book_idx], outputs=[output_layer] 259 | ) 260 | model.compile( 261 | "Adagrad", 262 | "binary_crossentropy", 263 | metrics=["accuracy", f1_batch, precision_batch, recall_batch], 264 | ) 265 | return model 266 | 267 | 268 | # %% [markdown] 269 | # We use triples of (`book_idxs`, `book_idx`, `rating`) from our dataframes as training data points. In addition, we want to train the model to recognize when a user will not read a book. To create data points for that, we randomly sample a `book_idx` not in `book_idxs` and use that with a `rating` of 0 as a _random negative_ data point. We create one such _random negative_ data point for every positive (`rating` 1) data point in our dataframe so that positive and negative data points are roughly balanced. 270 | 271 | # %% 272 | # Generator to turn dataframe into data points. 273 | def get_data_points_generator(df): 274 | def generator(): 275 | for book_idxs, book_idx, rating in zip(df.book_idxs, df.book_idx, df.rating): 276 | # Remove book_idx from book_idxs so the model can't just look it up. 277 | book_idxs = tuple(filter(lambda x: x != book_idx, book_idxs)) 278 | yield { 279 | "len_book_idxs": len(book_idxs), 280 | "book_idxs": book_idxs, 281 | "book_idx": book_idx, 282 | "label": rating, 283 | } 284 | if rating == 1: 285 | # Generate a random negative book_idx not in book_idxs. 286 | random_negative = np.random.randint(0, n_books) 287 | while random_negative in book_idxs: 288 | random_negative = np.random.randint(0, n_books) 289 | yield { 290 | "len_book_idxs": len(book_idxs), 291 | "book_idxs": book_idxs, 292 | "book_idx": random_negative, 293 | "label": 0, 294 | } 295 | 296 | return generator 297 | 298 | 299 | def get_data_tensors(df): 300 | # Use generator to get data points each epoch, along with shuffling and batching. 301 | padded_shapes = { 302 | "len_book_idxs": [], 303 | "book_idxs": [None], 304 | "book_idx": [], 305 | "label": [], 306 | } 307 | dataset = ( 308 | tf.data.Dataset.from_generator( 309 | get_data_points_generator(df), {k: tf.int64 for k in padded_shapes} 310 | ) 311 | .shuffle(123) 312 | .repeat(None) 313 | .padded_batch(batch_size=256, padded_shapes=padded_shapes) 314 | ) 315 | tensor_dict = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() 316 | return ( 317 | ( 318 | tensor_dict["len_book_idxs"], 319 | tensor_dict["book_idxs"], 320 | tensor_dict["book_idx"], 321 | ), 322 | tensor_dict["label"], 323 | ) 324 | 325 | 326 | # %% [markdown] 327 | # We now train the model on the training data labeled by our LFs (`df_train_filtered`). 328 | # 329 | # %% {"tags": ["md-exclude-output"]} 330 | from utils import get_n_epochs 331 | 332 | model = get_model() 333 | 334 | X_train, Y_train = get_data_tensors(df_train_filtered) 335 | X_valid, Y_valid = get_data_tensors(df_valid) 336 | model.fit( 337 | X_train, 338 | Y_train, 339 | steps_per_epoch=300, 340 | validation_data=(X_valid, Y_valid), 341 | validation_steps=40, 342 | epochs=get_n_epochs(), 343 | verbose=1, 344 | ) 345 | # %% [markdown] 346 | # Finally, we evaluate the model's predicted ratings on our test data.
347 | # 348 | # %% 349 | X_test, Y_test = get_data_tensors(df_test) 350 | _ = model.evaluate(X_test, Y_test, steps=30) 351 | 352 | # %% [markdown] 353 | # Our model has generalized quite well to our test set! 354 | # Note that we should additionally measure ranking metrics, like precision@10, before deploying to production. 355 | 356 | # %% [markdown] 357 | # ## Summary 358 | # 359 | # In this tutorial, we showed one way to use Snorkel for recommendations. 360 | # We used book metadata and review text to create LFs that estimate user ratings. 361 | # We used Snorkel's `LabelModel` to combine the outputs of those LFs. 362 | # Finally, we trained a model to predict whether a user will read and like a given book (and therefore what books should be recommended to the user) based only on what books the user has interacted with in the past. 363 | # 364 | # Here we demonstrated one way to use Snorkel for training a recommender system. 365 | # Note, however, that this approach could easily be adapted to take advantage of additional information as it is available (e.g., user profile data, denser user ratings, and so on). 366 | -------------------------------------------------------------------------------- /spam/02_spam_data_augmentation_tutorial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # %% [markdown] 3 | # # 📈 Snorkel Intro Tutorial: Data Augmentation 4 | 5 | # %% [markdown] 6 | # In this tutorial, we will walk through the process of using *transformation functions* (TFs) to perform data augmentation. 7 | # Like the labeling tutorial, our goal is to train a classifier to classify YouTube comments as `SPAM` or `HAM` (not spam). 8 | # In the [previous tutorial](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/01_spam_tutorial.ipynb), 9 | # we demonstrated how to label training sets programmatically with Snorkel. 10 | # In this tutorial, we'll assume that step has already been done, and start with labeled training data, 11 | # which we'll aim to augment using transformation functions. 12 | # 13 | # %% [markdown] {"tags": ["md-exclude"]} 14 | # * For more details on the task, check out the [labeling tutorial](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/01_spam_tutorial.ipynb) 15 | # * For an overview of Snorkel, visit [snorkel.org](https://snorkel.org) 16 | # * You can also check out the [Snorkel API documentation](https://snorkel.readthedocs.io/) 17 | # 18 | # %% [markdown] 19 | # Data augmentation is a popular technique for increasing the size of labeled training sets by applying class-preserving transformations to create copies of labeled data points. 20 | # In the image domain, it is a crucial factor in almost every state-of-the-art result today and is quickly gaining 21 | # popularity in text-based applications. 22 | # Snorkel models the data augmentation process by applying user-defined *transformation functions* (TFs) in sequence. 23 | # You can learn more about data augmentation in 24 | # [this blog post about our NeurIPS 2017 work on automatically learned data augmentation](https://snorkel.org/blog/tanda/). 25 | # 26 | # The tutorial is divided into four parts: 27 | # 1. **Loading Data**: We load a [YouTube comments dataset](http://www.dt.fee.unicamp.br/~tiago//youtubespamcollection/). 28 | # 2. **Writing Transformation Functions**: We write Transformation Functions (TFs) that can be applied to training data points to generate new training data points. 29 | # 3.
**Applying Transformation Functions to Augment Our Dataset**: We apply a sequence of TFs to each training data point, using a random policy, to generate an augmented training set. 30 | # 4. **Training a Model**: We use the augmented training set to train an LSTM model for classifying new comments as `SPAM` or `HAM`. 31 | 32 | # %% [markdown] {"tags": ["md-exclude"]} 33 | # This next cell takes care of some notebook-specific housekeeping. 34 | # You can ignore it. 35 | 36 | # %% {"tags": ["md-exclude"]} 37 | import os 38 | import random 39 | 40 | import numpy as np 41 | 42 | # Make sure we're running from the spam/ directory 43 | if os.path.basename(os.getcwd()) == "snorkel-tutorials": 44 | os.chdir("spam") 45 | 46 | # Turn off TensorFlow logging messages 47 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 48 | 49 | # For reproducibility 50 | seed = 0 51 | os.environ["PYTHONHASHSEED"] = str(seed) 52 | np.random.seed(0) 53 | random.seed(0) 54 | 55 | # %% [markdown] {"tags": ["md-exclude"]} 56 | # If you want to display all comment text untruncated, change `DISPLAY_ALL_TEXT` to `True` below. 57 | 58 | # %% {"tags": ["md-exclude"]} 59 | import pandas as pd 60 | 61 | 62 | DISPLAY_ALL_TEXT = False 63 | 64 | pd.set_option("display.max_colwidth", 0 if DISPLAY_ALL_TEXT else 50) 65 | 66 | # %% [markdown] {"tags": ["md-exclude"]} 67 | # This next cell makes sure a spaCy English model is downloaded. 68 | # If this is your first time downloading this model, restart the kernel after executing the next cell. 69 | 70 | # %% {"tags": ["md-exclude"]} 71 | # Download the spaCy English model 72 | # ! python -m spacy download en_core_web_sm 73 | 74 | # %% [markdown] 75 | # ## 1. Loading Data 76 | 77 | # %% [markdown] 78 | # We load the Kaggle dataset and create Pandas DataFrame objects for the `train` and `test` sets. 79 | # The two main columns in the DataFrames are: 80 | # * **`text`**: Raw text content of the comment 81 | # * **`label`**: Whether the comment is `SPAM` (1) or `HAM` (0). 82 | # 83 | # For more details, check out the [labeling tutorial](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/01_spam_tutorial.ipynb). 84 | 85 | # %% 86 | from utils import load_spam_dataset 87 | 88 | df_train, df_test = load_spam_dataset(load_train_labels=True) 89 | 90 | # We pull out the label vectors for ease of use later 91 | Y_train = df_train["label"].values 92 | Y_test = df_test["label"].values 93 | 94 | 95 | # %% 96 | df_train.head() 97 | 98 | # %% [markdown] 99 | # ## 2. Writing Transformation Functions (TFs) 100 | # 101 | # Transformation functions are functions that can be applied to a training data point to create another valid training data point of the same class. 102 | # For example, for image classification problems, it is common to rotate or crop images in the training data to create new training inputs. 103 | # Transformation functions should be atomic, e.g., a small rotation of an image, or changing a single word in a sentence. 104 | # We then compose multiple transformation functions when applying them to training data points. 105 | # 106 | # Common ways to augment text include replacing words with their synonyms, or replacing named entities with other entities. 107 | # More info can be found 108 | # [here](https://towardsdatascience.com/data-augmentation-in-nlp-2801a34dfc28) or 109 | # [here](https://towardsdatascience.com/these-are-the-easiest-data-augmentation-techniques-in-natural-language-processing-you-can-think-of-88e393fd610).
110 | # Our basic modeling assumption is that applying these operations to a comment generally shouldn't change whether it is `SPAM` or not. 111 | # 112 | # Transformation functions in Snorkel are created with the 113 | # [`transformation_function` decorator](https://snorkel.readthedocs.io/en/master/packages/_autosummary/augmentation/snorkel.augmentation.transformation_function.html#snorkel.augmentation.transformation_function), 114 | # which wraps a function that takes in a single data point and returns a transformed version of the data point. 115 | # If no transformation is possible, a TF can return `None` or the original data point. 116 | # If all the TFs applied to a data point return `None`, the data point won't be included in 117 | # the augmented dataset when we apply our TFs below. 118 | # 119 | # Just like the `labeling_function` decorator, the `transformation_function` decorator 120 | # accepts `pre` argument for `Preprocessor` objects. 121 | # Here, we'll use a 122 | # [`SpacyPreprocessor`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/preprocess/snorkel.preprocess.nlp.SpacyPreprocessor.html#snorkel.preprocess.nlp.SpacyPreprocessor). 123 | 124 | # %% 125 | from snorkel.preprocess.nlp import SpacyPreprocessor 126 | 127 | spacy = SpacyPreprocessor(text_field="text", doc_field="doc", memoize=True) 128 | 129 | # %% 130 | import names 131 | from snorkel.augmentation import transformation_function 132 | 133 | # Pregenerate some random person names to replace existing ones with 134 | # for the transformation strategies below 135 | replacement_names = [names.get_full_name() for _ in range(50)] 136 | 137 | 138 | # Replace a random named entity with a different entity of the same type. 139 | @transformation_function(pre=[spacy]) 140 | def change_person(x): 141 | person_names = [ent.text for ent in x.doc.ents if ent.label_ == "PERSON"] 142 | # If there is at least one person name, replace a random one. Else return None. 143 | if person_names: 144 | name_to_replace = np.random.choice(person_names) 145 | replacement_name = np.random.choice(replacement_names) 146 | x.text = x.text.replace(name_to_replace, replacement_name) 147 | return x 148 | 149 | 150 | # Swap two adjectives at random. 151 | @transformation_function(pre=[spacy]) 152 | def swap_adjectives(x): 153 | adjective_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "ADJ"] 154 | # Check that there are at least two adjectives to swap. 155 | if len(adjective_idxs) >= 2: 156 | idx1, idx2 = sorted(np.random.choice(adjective_idxs, 2, replace=False)) 157 | # Swap tokens in positions idx1 and idx2. 158 | x.text = " ".join( 159 | [ 160 | x.doc[:idx1].text, 161 | x.doc[idx2].text, 162 | x.doc[1 + idx1 : idx2].text, 163 | x.doc[idx1].text, 164 | x.doc[1 + idx2 :].text, 165 | ] 166 | ) 167 | return x 168 | 169 | 170 | # %% [markdown] 171 | # We add some transformation functions that use `wordnet` from [NLTK](https://www.nltk.org/) to replace different parts of speech with their synonyms. 172 | 173 | # %% {"tags": ["md-exclude-output"]} 174 | import nltk 175 | from nltk.corpus import wordnet as wn 176 | 177 | nltk.download("wordnet") 178 | 179 | 180 | def get_synonym(word, pos=None): 181 | """Get synonym for word given its part-of-speech (pos).""" 182 | synsets = wn.synsets(word, pos=pos) 183 | # Return None if wordnet has no synsets (synonym sets) for this word and pos. 
184 | if synsets: 185 | words = [lemma.name() for lemma in synsets[0].lemmas()] 186 | if words[0].lower() != word.lower(): # Skip if synonym is same as word. 187 | # Multi word synonyms in wordnet use '_' as a separator e.g. reckon_with. Replace it with space. 188 | return words[0].replace("_", " ") 189 | 190 | 191 | def replace_token(spacy_doc, idx, replacement): 192 | """Replace token in position idx with replacement.""" 193 | return " ".join([spacy_doc[:idx].text, replacement, spacy_doc[1 + idx :].text]) 194 | 195 | 196 | @transformation_function(pre=[spacy]) 197 | def replace_verb_with_synonym(x): 198 | # Get indices of verb tokens in sentence. 199 | verb_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "VERB"] 200 | if verb_idxs: 201 | # Pick random verb idx to replace. 202 | idx = np.random.choice(verb_idxs) 203 | synonym = get_synonym(x.doc[idx].text, pos="v") 204 | # If there's a valid verb synonym, replace it. Otherwise, return None. 205 | if synonym: 206 | x.text = replace_token(x.doc, idx, synonym) 207 | return x 208 | 209 | 210 | @transformation_function(pre=[spacy]) 211 | def replace_noun_with_synonym(x): 212 | # Get indices of noun tokens in sentence. 213 | noun_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "NOUN"] 214 | if noun_idxs: 215 | # Pick random noun idx to replace. 216 | idx = np.random.choice(noun_idxs) 217 | synonym = get_synonym(x.doc[idx].text, pos="n") 218 | # If there's a valid noun synonym, replace it. Otherwise, return None. 219 | if synonym: 220 | x.text = replace_token(x.doc, idx, synonym) 221 | return x 222 | 223 | 224 | @transformation_function(pre=[spacy]) 225 | def replace_adjective_with_synonym(x): 226 | # Get indices of adjective tokens in sentence. 227 | adjective_idxs = [i for i, token in enumerate(x.doc) if token.pos_ == "ADJ"] 228 | if adjective_idxs: 229 | # Pick random adjective idx to replace. 230 | idx = np.random.choice(adjective_idxs) 231 | synonym = get_synonym(x.doc[idx].text, pos="a") 232 | # If there's a valid adjective synonym, replace it. Otherwise, return None. 233 | if synonym: 234 | x.text = replace_token(x.doc, idx, synonym) 235 | return x 236 | 237 | 238 | # %% 239 | tfs = [ 240 | change_person, 241 | swap_adjectives, 242 | replace_verb_with_synonym, 243 | replace_noun_with_synonym, 244 | replace_adjective_with_synonym, 245 | ] 246 | 247 | # %% [markdown] 248 | # Let's check out a few examples of transformed data points to see what our TFs are doing. 249 | 250 | # %% 251 | from utils import preview_tfs 252 | 253 | preview_tfs(df_train, tfs) 254 | 255 | # %% [markdown] 256 | # We notice a couple of things about the TFs. 257 | # 258 | # * Sometimes they make trivial changes (`"website"` to `"web site"` for replace_noun_with_synonym). 259 | # This can still be helpful for training our model, because it teaches the model to be invariant to such small changes. 260 | # * Sometimes they introduce incorrect grammar to the sentence (e.g. `swap_adjectives` swapping `"young"` and `"more"` above). 261 | # 262 | # The TFs are expected to be heuristic strategies that indeed preserve the class most of the time, but 263 | # [don't need to be perfect](https://arxiv.org/pdf/1901.11196.pdf). 264 | # This is especially true when using automated 265 | # [data augmentation techniques](https://snorkel.org/blog/tanda/) 266 | # which can learn to avoid particularly corrupted data points. 267 | # As we'll see below, Snorkel is compatible with such learned augmentation policies. 268 | 269 | # %% [markdown] 270 | # ## 3. 
Applying Transformation Functions 271 | 272 | # %% [markdown] 273 | # We'll first define a `Policy` to determine what sequence of TFs to apply to each data point. 274 | # We'll start with a [`RandomPolicy`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/augmentation/snorkel.augmentation.RandomPolicy.html) 275 | # that samples `sequence_length=2` TFs to apply uniformly at random per data point. 276 | # The `n_per_original` argument determines how many augmented data points to generate per original data point. 277 | 278 | # %% 279 | from snorkel.augmentation import RandomPolicy 280 | 281 | random_policy = RandomPolicy( 282 | len(tfs), sequence_length=2, n_per_original=2, keep_original=True 283 | ) 284 | 285 | # %% [markdown] 286 | # In some cases, we can do better than uniform random sampling. 287 | # We might have domain knowledge that some TFs should be applied more frequently than others, 288 | # or have trained an [automated data augmentation model](https://snorkel.org/blog/tanda/) 289 | # that learned a sampling distribution for the TFs. 290 | # Snorkel supports this use case with a 291 | # [`MeanFieldPolicy`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/augmentation/snorkel.augmentation.MeanFieldPolicy.html), 292 | # which allows you to specify a sampling distribution for the TFs. 293 | # We give higher probabilities to the `replace_[X]_with_synonym` TFs, since those provide more information to the model. 294 | 295 | # %% 296 | from snorkel.augmentation import MeanFieldPolicy 297 | 298 | mean_field_policy = MeanFieldPolicy( 299 | len(tfs), 300 | sequence_length=2, 301 | n_per_original=2, 302 | keep_original=True, 303 | p=[0.05, 0.05, 0.3, 0.3, 0.3], 304 | ) 305 | 306 | # %% [markdown] 307 | # To apply one or more TFs that we've written to a collection of data points according to our policy, we use a 308 | # [`PandasTFApplier`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/augmentation/snorkel.augmentation.PandasTFApplier.html) 309 | # because our data points are represented with a Pandas DataFrame. 310 | 311 | # %% {"tags": ["md-exclude-output"]} 312 | from snorkel.augmentation import PandasTFApplier 313 | 314 | tf_applier = PandasTFApplier(tfs, mean_field_policy) 315 | df_train_augmented = tf_applier.apply(df_train) 316 | Y_train_augmented = df_train_augmented["label"].values 317 | 318 | # %% 319 | print(f"Original training set size: {len(df_train)}") 320 | print(f"Augmented training set size: {len(df_train_augmented)}") 321 | 322 | # %% [markdown] 323 | # We have almost tripled our dataset using TFs! 324 | # Note that despite `n_per_original` being set to 2, our dataset may not exactly triple in size, 325 | # because sometimes TFs return `None` instead of a new data point 326 | # (e.g. `change_person` when applied to a sentence with no persons). 327 | # If you prefer to have exact proportions for your dataset, you can have TFs that can't perform a 328 | # valid transformation return the original data point rather than `None` (as they do here). 329 | 330 | 331 | # %% [markdown] 332 | # ## 4. Training a Model 333 | # 334 | # Our final step is to use the augmented data to train a model. We train an LSTM (Long Short-Term Memory) model, which is a very standard architecture for text processing tasks. 335 | 336 | # %% [markdown] {"tags": ["md-exclude"]} 337 | # The next cell makes Keras results reproducible. You can ignore it.
338 | 339 | # %% {"tags": ["md-exclude"]} 340 | import tensorflow as tf 341 | 342 | session_conf = tf.compat.v1.ConfigProto( 343 | intra_op_parallelism_threads=1, inter_op_parallelism_threads=1 344 | ) 345 | 346 | tf.compat.v1.set_random_seed(0) 347 | sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf) 348 | tf.compat.v1.keras.backend.set_session(sess) 349 | 350 | # %% [markdown] 351 | # Now we'll train our LSTM on both the original and augmented datasets to compare performance. 352 | 353 | # %% {"tags": ["md-exclude-output"]} 354 | from utils import featurize_df_tokens, get_keras_lstm 355 | 356 | X_train = featurize_df_tokens(df_train) 357 | X_train_augmented = featurize_df_tokens(df_train_augmented) 358 | X_test = featurize_df_tokens(df_test) 359 | 360 | 361 | def train_and_test(X_train, Y_train, X_test=X_test, Y_test=Y_test, num_buckets=30000): 362 | # Define a vanilla LSTM model with Keras 363 | lstm_model = get_keras_lstm(num_buckets) 364 | lstm_model.fit(X_train, Y_train, epochs=5, verbose=0) 365 | preds_test = lstm_model.predict(X_test)[:, 0] > 0.5 366 | return (preds_test == Y_test).mean() 367 | 368 | 369 | acc_augmented = train_and_test(X_train_augmented, Y_train_augmented) 370 | acc_original = train_and_test(X_train, Y_train) 371 | 372 | # %% 373 | print(f"Test Accuracy (original training data): {100 * acc_original:.1f}%") 374 | print(f"Test Accuracy (augmented training data): {100 * acc_augmented:.1f}%") 375 | 376 | 377 | # %% [markdown] 378 | # So using the augmented dataset indeed improved our model! 379 | # There is a lot more you can do with data augmentation, so try a few ideas 380 | # out on your own! 381 | --------------------------------------------------------------------------------