├── .dockerignore ├── .gitignore ├── .gitmodules ├── Dockerfile ├── Introduction.ipynb ├── README.md ├── bin └── get_wikidata.sh ├── ch02 ├── .gitignore ├── Weakly Supervised Learning - Stack Overflow Tag Labeler.ipynb ├── emr_bootstrap.sh ├── get_questions.spark.py ├── images │ └── kim_cnn_model_architecture.png └── xml_to_parquet.spark.py ├── ch03 ├── Introducing Snorkel.ipynb └── images │ ├── labeling_function_api.png │ ├── snorkel_apis_0.9.5.png │ └── snorkel_tutorial_functions.png ├── ch04 ├── .gitignore ├── Chapter 4 - Github Embeddings.ipynb ├── Chapter 4 - Transfer Learning.ipynb ├── PREREQUISITES.md └── bert.sh ├── ch05 ├── Distant Supervision.ipynb ├── Snorkel.ipynb ├── bad_tags.spark.py ├── label.spark.py └── split_tags.spark.py ├── conda.env.yaml ├── conda.pip.requirements.txt ├── conda.requirements.txt ├── data ├── .exists └── amazon_github_repos.json.bz2 ├── docker-compose.yml ├── download.sh ├── lib └── utils.py ├── paths.json ├── requirements.dev.in ├── requirements.in └── settings.json /.dockerignore: -------------------------------------------------------------------------------- 1 | data 2 | bert 3 | __pycache__ 4 | snorkel 5 | snorkel-tutorials 6 | src 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | venv 3 | wandb 4 | .DS_Store 5 | .vscode 6 | data 7 | nohup.out 8 | models 9 | simple_log.jsonl 10 | __pycache__ 11 | src 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "bert"] 2 | path = bert 3 | url = https://github.com/google-research/bert 4 | [submodule "snorkel-tutorials"] 5 | path = snorkel-tutorials 6 | url = https://github.com/snorkel-team/snorkel-tutorials 7 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM continuumio/anaconda3:latest 2 | 3 | # Copy this directory over, only to be over-ridden by the docker-compose volume of same 4 | COPY . /weakly_supervised_learning_code 5 | WORKDIR /weakly_supervised_learning_code 6 | 7 | # Install Python dependencies and setup Jupyter without authentication 8 | RUN pip install -r requirements.in && \ 9 | jupyter notebook --generate-config && \ 10 | echo "c.NotebookApp.token = ''" >> ~/.jupyter/jupyter_notebook_config.py 11 | 12 | # Run Jupyter 13 | CMD jupyter notebook --port=8888 --no-browser --ip=0.0.0.0 --allow-root 14 | -------------------------------------------------------------------------------- /Introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Weakly Supervised Learning\n", 8 | "\n", 9 | "Welcome to *Weakly Supervised Learning* (O'Reilly Media, 2020) by [Russell Jurney](https://linkedin.com/in/russelljurney). 
\n", 10 | "\n", 11 | "The book's examples are organized by chapter:\n", 12 | "\n", 13 | "* [Chapter 1]()\n", 14 | "* [Chapter 2]()" 15 | ] 16 | } 17 | ], 18 | "metadata": { 19 | "kernelspec": { 20 | "display_name": "Python 3", 21 | "language": "python", 22 | "name": "python3" 23 | }, 24 | "language_info": { 25 | "codemirror_mode": { 26 | "name": "ipython", 27 | "version": 3 28 | }, 29 | "file_extension": ".py", 30 | "mimetype": "text/x-python", 31 | "name": "python", 32 | "nbconvert_exporter": "python", 33 | "pygments_lexer": "ipython3", 34 | "version": "3.7.6" 35 | } 36 | }, 37 | "nbformat": 4, 38 | "nbformat_minor": 2 39 | } 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Weakly Supervised Learning 2 | 3 | This is the source code for the book *Weakly Supervied Learning* an incomplete book from 2019-2020 by Russell Jurney. The book itself is open source and can be found at :) 4 | 5 | In my previous book, [Agile Data Science 2.0](http://shop.oreilly.com/product/0636920051619.do) (O’Reilly Media, 2017), I [setup EC2 and Vagrant environments](https://github.com/rjurney/Agile_Data_Code_2) in which to run the book’s code but since 2017 the Python ecosystem has developed to the point that I am going to refrain from providing thorough installation documentation for every requirement. In this book I provide a Docker setup that is easy to use and also provide Anaconda and PyPi environments if you wish the run the code yourself locally. The website for each library is a better resource than I can possibly create, and they are updated and maintained more frequently than this book. I will instead list requirements, link to the project pages and let the reader install the requirements themselves. If you want to use a pre-built environment, use the `Dockerfile` and `docker-compose.yml` files included in the [code repository for the book](https://github.com/rjurney/weakly_supervised_learning_code) will “just work” on any operating system that Docker runs on: Linux, Mac OS X, Windows. 6 | 7 | ## Software Prerequisites 8 | 9 | * Linux, Mac OS X or Windows - any OS with a Docker implementation 10 | * [Git](https://git-scm.com/download) is used to check out the book’s source code 11 | * [Docker](https://www.docker.com/get-started) is used to run the book’s examples in the same environment I wrote them in 12 | 13 | ## Running Docker via `docker-compose` 14 | 15 | To run the examples using `docker-compose` simply run: 16 | 17 | ```bash 18 | docker-compose up --build -d 19 | ``` 20 | 21 | The `--build` builds the container using the local directory the first time you run it. The `-d` puts the Jupyter web server in the background, and is optional. 22 | 23 | Now visit [http://localhost:8888](http://localhost:8888) 24 | 25 | If you run into problems, remove the `-d` argument to run it in the foreground and [file an issue](https://github.com/rjurney/weakly_supervised_learning_code/issues/new) on Github with the command you used and the complete error output. 26 | 27 | ## Running Docker directly via the `Dockerfile` 28 | 29 | You can also build and run the docker image directly via the `docker` command and the `Dockerfile` : 30 | 31 | ```bash 32 | docker build --tag weakly_supervised_learning . 
33 | docker container run \ 34 | --publish 8888:8888 \ 35 | --detach \ 36 | --name weakly_supervised_learning \ 37 | -v .:/weakly_supervised_learning_code \ 38 | weakly_supervised_learning 39 | ``` 40 | 41 | Now visit [http://localhost:8888](http://localhost:8888) 42 | 43 | If you run into problems, remove the `--detach` argument to run it in the foreground and [file an issue](https://github.com/rjurney/weakly_supervised_learning_code/issues/new) on Github with the command you used and the complete error output. 44 | 45 | ## Running via Docker Hub 46 | 47 | You can also use Docker Hub to pull and run the image directly: 48 | 49 | ```bash 50 | docker pull rjurney/weakly_supervised_learning 51 | docker run weakly_supervised_learning # add a volume for . 52 | ``` 53 | 54 | Now visit [http://localhost:8888](http://localhost:8888) 55 | 56 | ## Bugs, Errors or other Problems 57 | 58 | If you run into problems, make sure you have the latest code with `git pull origin master` and if it persist then [search the Github issues](https://github.com/rjurney/weakly_supervised_learning_code/issues?utf8=%E2%9C%93&q=is%3Aissue+) for the error. If a fix isn’t in the issues, then [create a ticket](https://github.com/rjurney/weakly_supervised_learning_code/issues/new) and include the command you ran and the complete output of that command. You can find the Book’s issues on Github here: [https://github.com/rjurney/weakly_supervised_learning_code/issues](https://github.com/rjurney/weakly_supervised_learning_code/issues). 59 | 60 | ## Running the Code Locally 61 | 62 | I’ve defined two Python environments for the book using Conda and a Virtual Environment. Once you have setup the requirements, you can easily reproduce the environment in which the book was written and tested. 63 | 64 | ### Software Prerequisites 65 | 66 | The following requirements are needed if you run the code locally: 67 | 68 | * Python 3.7+ - I recommend [Anaconda Python](https://www.anaconda.com/distribution/), but any Python will do 69 | * `conda` or `virtualenv` to recreate the Python environment I wrote the examples in 70 | * Recommended: An NVIDIA graphics card - you can work the examples without one, but CPU training is painfully slow 71 | * Recommended: [CUDA 10.0](https://developer.nvidia.com/cuda-10.0-download-archive) - for GPU acceleration in CuPy and Tensorflow 72 | * Recommended: [cuDNN](https://developer.nvidia.com/cudnn) - for GPU acceleration in Tensorflow 73 | 74 | The file `environment.yml` lists them for the `conda` environment system used by [Anaconda Python](https://www.anaconda.com/distribution/). The library dependencies for the book are also defined in `requirements.in`, which [PyPi - the Python Package Index](https://pypi.org/) can use via the `pip` command to install them. I recommend PyPi users create a [Virtual Environment](https://virtualenv.pypa.io/en/latest/user_guide.html#introduction) to ensure you replicate the book’s environment accurately. 75 | 76 | The examples in the book are run as [Jupyter Notebooks](https://jupyter.org/). Jupyter is included in both `conda` and `pip` environments. 
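
If you are using a GPU, you can confirm that TensorFlow sees it once one of the environments below is activated. This is the same check the Chapter 4 notebooks run; a minimal sketch, assuming TensorFlow installed correctly:

```bash
python -c "import tensorflow as tf; print('GPU available:', tf.test.is_gpu_available(cuda_only=False))"
```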
77 | 78 | ### Anaconda Python 3 79 | 80 | To create a `conda` environment for the book, run: 81 | 82 | ```bash 83 | conda env create -f environment.yml 84 | conda activate weak 85 | ``` 86 | 87 | To deactivate the environment, run: 88 | 89 | ```bash 90 | conda deactivate 91 | ``` 92 | 93 | ### Virtual Environment 94 | 95 | To create a Virtual Environment in which to install the PyPi dependencies, run: 96 | 97 | ```bash 98 | pip install --upgrade virtualenv 99 | virtualenv -p `which python3` weak 100 | source weak/bin/activate 101 | pip install -r requirements.in 102 | ``` 103 | 104 | To deactivate the Virtual Environment, run: 105 | 106 | ```bash 107 | source deactivate 108 | ``` 109 | 110 | ### Running Jupyter 111 | 112 | If you’re using Docker, the image will install and run Jupyter for you. If you’re using your own Python environment, you need to run Jupyter: 113 | 114 | ```bash 115 | cd 116 | jupyter notebook & 117 | ``` 118 | 119 | Then visit [http://localhost:8888](http://localhost:8888) and open [Introduction.ipynb](https://github.com/rjurney/weakly_supervised_learning_code/blob/master/Introduction.ipynb) or select the chapter file you want to read and run. 120 | -------------------------------------------------------------------------------- /bin/get_wikidata.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Get all entities in wikidata 4 | curl -Lko data/wikidata/entities-latest-all.json.bz2 https://dumps.wikimedia.org/wikidatawiki/entities/latest-all.json.bz2 5 | 6 | # Get all wikipedia pages 7 | curl -Lko data/wikidata/enwiki-latest-pages-articles-multistream.xml.bz2 https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles-multistream.xml.bz2 8 | -------------------------------------------------------------------------------- /ch02/.gitignore: -------------------------------------------------------------------------------- 1 | models 2 | -------------------------------------------------------------------------------- /ch02/emr_bootstrap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x -e 3 | 4 | # Setup Python 5 | sudo yum -y install python-devel 6 | 7 | # Install all required modules 8 | sudo `which pip3` install lxml frozendict ipython pandas boto3 bs4 nltk 9 | 10 | # Download nltk data 11 | python3 -m nltk.downloader punkt 12 | python3 -m nltk.downloader stopwords 13 | 14 | # Install Mosh for long running ssh that won't die 15 | sudo yum -y install mosh git 16 | 17 | # Install requirements for Snorkel processing 18 | sudo `which pip3` install beautifulsoup4 dill gensim iso8601 jupyter numpy pandas<0.26.0 pip-tools pyarrow requests s3fs scikit-learn<0.22.0 snorkel spacy textblob textdistance texttable 19 | 20 | # Download another model 21 | sudo python3 -m spacy download en_core_web_lg 22 | 23 | # Set ipython as the default shell for pyspark 24 | export PYSPARK_DRIVER_PYTHON=ipython3 25 | echo "" >> /home/hadoop/.bash_profile 26 | echo "# Set ipython as the default shell for pyspark" >> /home/hadoop/.bash_profile 27 | echo "export PYSPARK_DRIVER_PYTHON=ipython3" >> /home/hadoop/.bash_profile 28 | echo "" >> /home/hadoop/.bash_profile 29 | -------------------------------------------------------------------------------- /ch02/get_questions.spark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # 4 | # This script extracts the text and code of Stack Overflow questions related 
to Python. 5 | # 6 | # Run me with: PYSPARK_DRIVER_PYTHON=ipython3 PYSPARK_PYTHON=python3 pyspark 7 | # 8 | 9 | import re 10 | 11 | from pyspark.sql import SparkSession, Row 12 | import pyspark.sql.functions as F 13 | from pyspark.sql.functions import udf 14 | import pyspark.sql.types as T 15 | 16 | 17 | DEBUG = True 18 | 19 | 20 | # 21 | # Initialize Spark with dynamic allocation enabled to (hopefully) use less RAM 22 | # 23 | spark = SparkSession.builder\ 24 | .appName('Weakly Supervised Learning - Extract Questions')\ 25 | .getOrCreate() 26 | sc = spark.sparkContext 27 | 28 | 29 | # 30 | # Get answered questions and not their answers 31 | # 32 | posts = spark.read.parquet('s3://stackoverflow-events/2020-06-01/Posts.parquet') 33 | posts.show(3) 34 | 35 | if DEBUG: 36 | print('Total posts count: {:,}'.format( 37 | posts.count() 38 | )) 39 | 40 | # Questions are posts without a parent ID 41 | questions = posts.filter(posts.ParentId.isNull()) 42 | 43 | if DEBUG: 44 | print( 45 | f'Total questions count: {questions.count():,}' 46 | ) 47 | 48 | # Quality questions have at least one answer and at least one vote 49 | quality_questions = questions.filter(posts.AnswerCount > 0)\ 50 | .filter(posts.Score > 1) 51 | 52 | if DEBUG: 53 | print(f'Quality questions count: {quality_questions.count():,}') 54 | 55 | # Combine title with body 56 | tb_questions = quality_questions.withColumn( 57 | 'Title_Body', 58 | F.concat( 59 | F.col("Title"), 60 | F.lit(" "), 61 | F.col("Body") 62 | ), 63 | ) 64 | 65 | # Split the tags and replace the Tags column 66 | @udf(T.ArrayType(T.StringType())) 67 | def split_tags(tag_string): 68 | return re.sub('[<>]', ' ', tag_string).split() 69 | 70 | # Just look at Python questions 71 | python_questions = tb_questions.filter(tb_questions.Tags.contains('python')) 72 | 73 | if DEBUG: 74 | print(f'Python questions count: {python_questions.count()}') 75 | 76 | # Make tags a list all pretty like 77 | tag_questions = python_questions.withColumn( 78 | 'Tags', 79 | split_tags( 80 | F.col('Tags') 81 | ) 82 | ) 83 | 84 | # Show 5 records' Title and Tag fields, full field length 85 | tag_questions.select('Title', 'Tags').show() 86 | 87 | # Write all questions to a Parquet file 88 | tag_questions\ 89 | .write.mode('overwrite')\ 90 | .parquet('s3://stackoverflow-events/2020-06-01/PythonQuestions.parquet') 91 | -------------------------------------------------------------------------------- /ch02/images/kim_cnn_model_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rjurney/weakly_supervised_learning_code/1b1c2b1336384b0d4ca47846d52b459d61468197/ch02/images/kim_cnn_model_architecture.png -------------------------------------------------------------------------------- /ch02/xml_to_parquet.spark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # 4 | # Convert the Stack Overflow data from XML format to Parquet format for performance reasons. 
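# (spark-xml, loaded via --packages below, maps each <row .../> element to a DataFrame row
#  and exposes its XML attributes as columns prefixed with "_", which remove_prefix() strips.)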
5 | # Run me with: PYSPARK_DRIVER_PYTHON=ipython3 PYSPARK_PYTHON=python3 pyspark --packages com.databricks:spark-xml_2.11:0.9.0 6 | # 7 | 8 | import json 9 | 10 | from pyspark.sql import SparkSession 11 | import pyspark.sql.functions as F 12 | 13 | 14 | # Initialize PySpark 15 | spark = SparkSession.builder.appName('Weakly Supervised Learning - Convert XML to Parquet').getOrCreate() 16 | sc = spark.sparkContext 17 | 18 | 19 | def remove_prefix(df): 20 | """Remove the _ prefix that Spark-XML adds to all attributes""" 21 | field_names = [x.name for x in df.schema] 22 | new_field_names = [x[1:] for x in field_names] 23 | s = [] 24 | 25 | # Substitute the old name for the new one 26 | for old, new in zip(field_names, new_field_names): 27 | s.append( 28 | F.col(old).alias(new) 29 | ) 30 | return df.select(s) 31 | 32 | 33 | # Use Spark-XML to split the XML file into records 34 | posts_df = spark.read.format('xml')\ 35 | .options(rowTag='row')\ 36 | .options(rootTag='posts')\ 37 | .load('s3://stackoverflow-events/2020-06-01/Posts.xml') 38 | 39 | # Remove the _ prefix from field names 40 | posts_df = remove_prefix(posts_df) 41 | 42 | # Write the DataFrame out to Parquet format 43 | posts_df.write\ 44 | .mode('overwrite')\ 45 | .parquet('s3://stackoverflow-events/2020-06-01/Posts.parquet') 46 | 47 | 48 | # Use Spark-XML to split the XML file into records 49 | users_df = spark.read.format('xml')\ 50 | .options(rowTag='row')\ 51 | .options(rootTag='users')\ 52 | .load('s3://stackoverflow-events/2020-06-01/Users.xml') 53 | 54 | # Remove the _ prefix from field names 55 | users_df = remove_prefix(users_df) 56 | 57 | # Write the DataFrame out to Parquet format 58 | users_df.write\ 59 | .mode('overwrite')\ 60 | .parquet('s3://stackoverflow-events/2020-06-01/Users.parquet') 61 | 62 | 63 | # # Use Spark-XML to split the XML file into records 64 | # tags_df = spark.read.format('xml')\ 65 | # .options(rowTag='row')\ 66 | # .options(rootTag='tags')\ 67 | # .load('s3://stackoverflow-events/2020-06-01/Tags.xml') 68 | 69 | # # Remove the _ prefix from field names 70 | # tags_df = remove_prefix(tags_df) 71 | 72 | # # Write the DataFrame out to Parquet format 73 | # tags_df.write\ 74 | # .mode('overwrite')\ 75 | # .parquet('s3://stackoverflow-events/2020-06-01/Tags.parquet') 76 | 77 | 78 | # # Use Spark-XML to split the XML file into records 79 | # badges_df = spark.read.format('xml')\ 80 | # .options(rowTag='row')\ 81 | # .options(rootTag='badges')\ 82 | # .load('s3://stackoverflow-events/2020-06-01/Badges.xml') 83 | 84 | # # Remove the _ prefix from field names 85 | # badges_df = remove_prefix(badges_df) 86 | 87 | # # Write the DataFrame out to Parquet format 88 | # badges_df.write\ 89 | # .mode('overwrite')\ 90 | # .parquet('s3://stackoverflow-events/2020-06-01/Badges.parquet') 91 | 92 | 93 | # # Use Spark-XML to split the XML file into records 94 | # comments_df = spark.read.format('xml')\ 95 | # .options(rowTag='row')\ 96 | # .options(rootTag='comments')\ 97 | # .load('s3://stackoverflow-events/2020-06-01/Comments.xml') 98 | 99 | # # Remove the _ prefix from field names 100 | # comments_df = remove_prefix(comments_df) 101 | 102 | # # Write the DataFrame out to Parquet format 103 | # comments_df.write\ 104 | # .mode('overwrite')\ 105 | # .parquet('s3://stackoverflow-events/2020-06-01/Comments.parquet') 106 | 107 | 108 | # # Use Spark-XML to split the XML file into records 109 | # post_links_df = spark.read.format('xml')\ 110 | # .options(rowTag='row')\ 111 | # .options(rootTag='postlinks')\ 112 | # 
.load('s3://stackoverflow-events/2020-06-01/PostLinks.xml') 113 | 114 | # # Remove the _ prefix from field names 115 | # post_links_df = remove_prefix(post_links_df) 116 | 117 | # # Write the DataFrame out to Parquet format 118 | # post_links_df.write\ 119 | # .mode('overwrite')\ 120 | # .parquet('s3://stackoverflow-events/2020-06-01/PostLinks.parquet') 121 | -------------------------------------------------------------------------------- /ch03/images/labeling_function_api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rjurney/weakly_supervised_learning_code/1b1c2b1336384b0d4ca47846d52b459d61468197/ch03/images/labeling_function_api.png -------------------------------------------------------------------------------- /ch03/images/snorkel_apis_0.9.5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rjurney/weakly_supervised_learning_code/1b1c2b1336384b0d4ca47846d52b459d61468197/ch03/images/snorkel_apis_0.9.5.png -------------------------------------------------------------------------------- /ch03/images/snorkel_tutorial_functions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rjurney/weakly_supervised_learning_code/1b1c2b1336384b0d4ca47846d52b459d61468197/ch03/images/snorkel_tutorial_functions.png -------------------------------------------------------------------------------- /ch04/.gitignore: -------------------------------------------------------------------------------- 1 | models 2 | -------------------------------------------------------------------------------- /ch04/Chapter 4 - Github Embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Chapter 3 - Github Embeddings\n", 8 | "\n", 9 | "In this notebook we're going to go beyond using pre-trained embeddings and models we download from the internet and start to create our own secondary models that can improve the primary model through transfer learning. We're going to train text and code embeddings based on Github's [CodeSearchNet](https://github.com/rjurney/CodeSearchNet) datasets. They include both doc strings and code for 2 million posts and while they use the data to map from text search queries to code, we'll be using it to create separate [BERT](https://arxiv.org/abs/1810.04805) embeddings to drive our Stack Overflow tagger.\n", 10 | "\n", 11 | "The paper for CodeSearchNet is on arXiv at [CodeSearchNet Challenge: Evaluating the State of Semantic Code Search](https://arxiv.org/abs/1909.09436)." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import csv\n", 21 | "import gc\n", 22 | "from pathlib import Path\n", 23 | "import os\n", 24 | "import random\n", 25 | "import sys\n", 26 | "import warnings\n", 27 | "\n", 28 | "from bs4 import BeautifulSoup\n", 29 | "from nltk.tokenize.punkt import PunktSentenceTokenizer\n", 30 | "import pandas as pd\n", 31 | "\n", 32 | "random.seed(1337)\n", 33 | "\n", 34 | "# Add parent directory to path\n", 35 | "parent_dir = os.path.dirname(os.getcwd())\n", 36 | "sys.path.append(parent_dir)\n", 37 | "\n", 38 | "from lib.utils import extract_text_plain\n", 39 | "\n", 40 | "# Disable all warnings\n", 41 | "warnings.filterwarnings(\"ignore\")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "## Load CodeSearchNet Data\n", 49 | "\n", 50 | "We load the entire CodeSearchNet dataset for Go, Java, PHP, Python and Ruby. While the code doesn't cover all languages I'm hoping they are diverse enough to handle other languages and so will still help performance." 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "df = pd.DataFrame()\n", 60 | "\n", 61 | "# Load all Gzipped JSON Lines files in the data directory\n", 62 | "for filename in Path('../data/CodeSearchNet').glob('**/*.jsonl.gz'):\n", 63 | " new_df = pd.read_json(filename, lines=True)\n", 64 | " df = pd.concat([df, new_df])\n", 65 | " \n", 66 | " # Carefully manage memory\n", 67 | " del new_df\n", 68 | " gc.collect()\n", 69 | "\n", 70 | "df.head()" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "print(\n", 80 | " f'There are {len(df[\"docstring\"].index):,} functions'\n", 81 | ")" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "## Extract Text from Docstrings\n", 89 | "\n", 90 | "Docstings can contain HTML, so we parse them and extract text using `BeautifulSoup`." 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "code = df['code']\n", 100 | "docs = df.docstring.apply(lambda x: extract_text_plain(x))" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "## Inspect the result of the code removal" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "pd.set_option('max_colwidth', 500)\n", 117 | "doc_df = pd.DataFrame({'docs': docs, 'docstring': df['docstring']})\n", 118 | "\n", 119 | "doc_df.head(3)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "## About BERT\n", 127 | "\n", 128 | "Google BERT is described in the [BERT README](https://github.com/google-research/bert/blob/master/README.md):\n", 129 | "\n", 130 | "> BERT is a method of pre-training language representations, meaning that we train a general-purpose \"language understanding\" model on a large text corpus (like Wikipedia), and then use that model for downstream NLP tasks that we care about (like question answering). 
BERT outperforms previous methods because it is the first unsupervised, deeply bidirectional system for pre-training NLP.\n", 131 | "\n", 132 | "> Unsupervised means that BERT was trained using only a plain text corpus, which is important because an enormous amount of plain text data is publicly available on the web in many languages.\n", 133 | "\n", 134 | "> Pre-trained representations can also either be context-free or contextual, and contextual representations can further be unidirectional or bidirectional. Context-free models such as word2vec or GloVe generate a single \"word embedding\" representation for each word in the vocabulary, so bank would have the same representation in bank deposit and river bank. Contextual models instead generate a representation of each word that is based on the other words in the sentence.\n", 135 | "\n", 136 | "> BERT was built upon recent work in pre-training contextual representations — including Semi-supervised Sequence Learning, Generative Pre-Training, ELMo, and ULMFit — but crucially these models are all unidirectional or shallowly bidirectional. This means that each word is only contextualized using the words to its left (or right). For example, in the sentence I made a bank deposit the unidirectional representation of bank is only based on I made a but not deposit. Some previous work does combine the representations from separate left-context and right-context models, but only in a \"shallow\" manner. BERT represents \"bank\" using both its left and right context — I made a ... deposit — starting from the very bottom of a deep neural network, so it is deeply bidirectional.\n", 137 | "\n", 138 | "## Generate CSV for BERT\n", 139 | "\n", 140 | "The [Google BERT Github project](https://github.com/google-research/bert) is a submodule to this project, which you can checkout from within this [cloned project](https://github.com/rjurney/weakly_supervised_learning_code) with:\n", 141 | "\n", 142 | "```bash\n", 143 | "git submodule init\n", 144 | "git submodule update\n", 145 | "```\n", 146 | "\n", 147 | "We need to generate CSV in the format that BERT expects, which is:\n", 148 | "\n", 149 | "> Here's how to run the data generation. The input is a plain text file, with one sentence per line. (It is important that these be actual sentences for the \"next sentence prediction\" task). Documents are delimited by empty lines. The output is a set of tf.train.Examples serialized into TFRecord file format." 
150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "sentence_tokenizer = PunktSentenceTokenizer()\n", 159 | "sentences = docs.apply(sentence_tokenizer.tokenize)\n", 160 | "\n", 161 | "sentences.head(2)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "with open('../data/sentences.csv', 'w') as f:\n", 171 | " \n", 172 | " current_idx = 0\n", 173 | " for idx, doc in sentences.items():\n", 174 | " # Insert a newline to separate documents\n", 175 | " if idx != current_idx:\n", 176 | " f.write('\\n')\n", 177 | " # Write each sentence exactly as it appared to one line each\n", 178 | " for sentence in doc:\n", 179 | " f.write(sentence.encode('unicode-escape').decode().replace('\\\\\\\\', '\\\\') + '\\n')" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "## Using `sentencepiece` to Extract a WordPiece Vocabulary\n", 187 | "\n", 188 | "BERT needs a WordPiece vocabulary file to run, so we need to decide on a number of tokens and then run `sentencepiece` to extract a list of valid tokens.\n", 189 | "\n", 190 | "The `sentencepiece` Pypi library isn't sufficient for our needs, we need to clone the Github repo, build and install the software to create our vocabulary.\n", 191 | "\n", 192 | "Make sure you're in the root directory of this project and run:\n", 193 | "\n", 194 | "```bash\n", 195 | "git clone https://github.com/google/sentencepiece\n", 196 | "cd sentencepiece\n", 197 | "\n", 198 | "mkdir build\n", 199 | "cd build\n", 200 | "cmake ..\n", 201 | "make -j $(nproc)\n", 202 | "sudo make install\n", 203 | "sudo ldconfig -v\n", 204 | "```\n", 205 | "\n", 206 | "Now we can use `sp_train` to create a vocabulary of our 4.7 million sentences." 
207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "%%bash\n", 216 | "\n", 217 | "cd ../models\n", 218 | "spm_train --input=\"../data/sentences.csv\" --model_prefix=wsl --vocab_size=20000\n", 219 | "\n", 220 | "# Add the [CLS], [SEP], [UNK] and [MASK] tags, or pre-training will error out\n", 221 | "echo -e \"[CLS]\\t0\\n[SEP]\\t0\\n[UNK]\\t0\\n[MASK]\\t0\\n$(cat wsl.vocab)\" > wsl.vocab\n", 222 | "\n", 223 | "# Remove the numbers, just retain the tag vocabulary\n", 224 | "cat wsl.vocab | cut -d$'\\t' -f1 > wsl.stripped.vocab" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": {}, 235 | "source": [ 236 | "## Using BERT to Pretrain a Language Model\n", 237 | "\n", 238 | "Next we use the WordPiece vocabulary to pre-train a BERT model that we will then use, as a tranfer learning strategy, to encode the text of Stack Overflow questions.\n", 239 | "\n", 240 | "### Creating a BERT conda environment\n", 241 | "\n", 242 | "It is not possible to create a new conda environment from which to install `tensorflow==1.14.0`, which BERT needs, so you will need to run this code outside of this notebook, from the root directory of this project.\n", 243 | "\n", 244 | "\n", 245 | "```bash\n", 246 | "conda create -y -n bert python=3.7.4\n", 247 | "conda init bash\n", 248 | "```\n", 249 | "\n", 250 | "Now in a new shell, change directory to the root of project:\n", 251 | "\n", 252 | "```bash\n", 253 | "cd /path/to/weakly_supervised_learning_code\n", 254 | "```\n", 255 | "\n", 256 | "Now run:\n", 257 | "\n", 258 | "```bash\n", 259 | "conda activate bert\n", 260 | "pip install tensorflow-gpu==1.14.0\n", 261 | "```\n", 262 | "\n", 263 | "### Creating BERT Pre-Training Data\n", 264 | "\n", 265 | "Before we can train a BERT model or extract static embedding values we need to create the pre-training data the model uses to train. The output file will be 20GB, so make sure you have the space available!\n", 266 | "\n", 267 | "From the [BERT README](https://github.com/google-research/bert/blob/master/README.md):\n", 268 | "\n", 269 | "> Here's how to run the data generation. The input is a plain text file, with one sentence per line. (It is important that these be actual sentences for the \"next sentence prediction\" task). Documents are delimited by empty lines. 
The output is a set of tf.train.Examples serialized into TFRecord file format.\n", 270 | "\n", 271 | "We need to configure BERT to use our vocabulary size, so we create a `bert_config.json` file in the `bert/` directory.\n", 272 | "\n", 273 | "```bash\n", 274 | "# Tell BERT how many tokens to use\n", 275 | "echo '{ \"vocab_size\": 20004 }' > bert/bert_config.json \n", 276 | "```\n", 277 | "\n", 278 | "Then we execute the `create_pretraining_data.py` command to pre-train the network.\n", 279 | "\n", 280 | "```bash\n", 281 | "python bert/create_pretraining_data.py \\\n", 282 | " --input_file=data/sentences.csv \\\n", 283 | " --output_file=data/tf_examples.tfrecord \\\n", 284 | " --vocab_file=models/wsl.stripped.vocab \\\n", 285 | " --bert_config_file=bert/bert_config.json \\\n", 286 | " --do_lower_case=False \\\n", 287 | " --max_seq_length=128 \\\n", 288 | " --max_predictions_per_seq=20 \\\n", 289 | " --num_train_steps=20 \\\n", 290 | " --num_warmup_steps=10 \\\n", 291 | " --random_seed=1337 \\\n", 292 | " --learning_rate=2e-5\n", 293 | "```\n", 294 | "\n", 295 | "Now we can run pretraining. If your GPU is only 8GB of RAM, reduce the training batch size to 16 or 24.\n", 296 | "\n", 297 | "```bash\n", 298 | "python bert/run_pretraining.py \\\n", 299 | " --input_file=data/tf_examples.tfrecord \\\n", 300 | " --output_dir=models/bert_pretraining_output \\\n", 301 | " --do_train=True \\\n", 302 | " --do_eval=True \\\n", 303 | " --bert_config_file=bert/bert_config.json \\\n", 304 | " --train_batch_size=32 \\\n", 305 | " --max_seq_length=128 \\\n", 306 | " --max_predictions_per_seq=20 \\\n", 307 | " --num_train_steps=10000 \\\n", 308 | " --num_warmup_steps=10 \\\n", 309 | " --learning_rate=2e-5\n", 310 | "```\n", 311 | "\n", 312 | "Finally, deactivate the conda environment:\n", 313 | "\n", 314 | "```bash\n", 315 | "conda deactivate\n", 316 | "```" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [] 332 | } 333 | ], 334 | "metadata": { 335 | "kernelspec": { 336 | "display_name": "Python 3", 337 | "language": "python", 338 | "name": "python3" 339 | }, 340 | "language_info": { 341 | "codemirror_mode": { 342 | "name": "ipython", 343 | "version": 3 344 | }, 345 | "file_extension": ".py", 346 | "mimetype": "text/x-python", 347 | "name": "python", 348 | "nbconvert_exporter": "python", 349 | "pygments_lexer": "ipython3", 350 | "version": "3.7.4" 351 | } 352 | }, 353 | "nbformat": 4, 354 | "nbformat_minor": 2 355 | } 356 | -------------------------------------------------------------------------------- /ch04/Chapter 4 - Transfer Learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Chapter 3 - Transfer Learning\n", 8 | "\n", 9 | "In this chapter we'll be exploring *transfer learning*, where a model trained for one purpose is used for another. We'll take our initial model and enhance it using text and source code embeddings.\n", 10 | "\n", 11 | "Since we have already explained the workings of this model in the previous chapter, the comments for the model basics have been removed. 
See [Chapter 2 (add link)]() and [Weakly Supervised Learning - Stack Overflow Tag Labeler.ipynb](../ch02/Weakly%20Supervised%20Learning%20-%20Stack%20Overflow%20Tag%20Labeler.ipynb) to learn more about the model itself." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stderr", 21 | "output_type": "stream", 22 | "text": [ 23 | "[nltk_data] Downloading package punkt to /home/rjurney/nltk_data...\n", 24 | "[nltk_data] Package punkt is already up-to-date!\n", 25 | "[nltk_data] Downloading package stopwords to\n", 26 | "[nltk_data] /home/rjurney/nltk_data...\n", 27 | "[nltk_data] Package stopwords is already up-to-date!\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "import gc\n", 33 | "import json\n", 34 | "import math\n", 35 | "import os\n", 36 | "import re\n", 37 | "import sys\n", 38 | "import warnings\n", 39 | "\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "import numpy as np\n", 42 | "import pandas as pd\n", 43 | "import pyarrow\n", 44 | "import tensorflow as tf\n", 45 | "import tensorflow_hub as hub\n", 46 | "\n", 47 | "# Add parent directory to path\n", 48 | "parent_dir = os.path.dirname(os.getcwd())\n", 49 | "sys.path.append(parent_dir)\n", 50 | "\n", 51 | "import lib.utils\n", 52 | "\n", 53 | "# Disable all warnings\n", 54 | "warnings.filterwarnings(\"ignore\")" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "np.random.seed(seed=1337)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "1 or more GPUs is available: True\n", 76 | "GPUs on tap: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "gpu_avail = tf.test.is_gpu_available(\n", 82 | " cuda_only=False,\n", 83 | " min_cuda_compute_capability=None\n", 84 | ")\n", 85 | "print(f'1 or more GPUs is available: {gpu_avail}')\n", 86 | "\n", 87 | "avail_gpus = tf.compat.v2.config.experimental.list_physical_devices('GPU')\n", 88 | "print(f'GPUs on tap: {avail_gpus}')" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "COLUMN_WIDTH = 50\n", 98 | "pd.set_option('display.max_colwidth', COLUMN_WIDTH)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 6, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "BATCH_SIZE = 128\n", 108 | "MAX_LEN = 200\n", 109 | "TOKEN_COUNT = 10000\n", 110 | "EMBED_SIZE = 300\n", 111 | "TEST_SPLIT = 0.3" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 8, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/html": [ 122 | "
\n", 123 | "\n", 136 | "\n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | "
_Bodylabel_0label_1label_2label_3label_4label_5label_6label_7label_8...label_776label_777label_778label_779label_780label_781label_782label_783label_784label_785
0[How, animate, Flutter, layout, keyboard, appe...000000000...0000000000
1[Creating, Carousel, using, FutureBuilder, I, ...000000000...0000000000
\n", 214 | "

2 rows × 787 columns

\n", 215 | "
" 216 | ], 217 | "text/plain": [ 218 | " _Body label_0 label_1 \\\n", 219 | "0 [How, animate, Flutter, layout, keyboard, appe... 0 0 \n", 220 | "1 [Creating, Carousel, using, FutureBuilder, I, ... 0 0 \n", 221 | "\n", 222 | " label_2 label_3 label_4 label_5 label_6 label_7 label_8 ... \\\n", 223 | "0 0 0 0 0 0 0 0 ... \n", 224 | "1 0 0 0 0 0 0 0 ... \n", 225 | "\n", 226 | " label_776 label_777 label_778 label_779 label_780 label_781 \\\n", 227 | "0 0 0 0 0 0 0 \n", 228 | "1 0 0 0 0 0 0 \n", 229 | "\n", 230 | " label_782 label_783 label_784 label_785 \n", 231 | "0 0 0 0 0 \n", 232 | "1 0 0 0 0 \n", 233 | "\n", 234 | "[2 rows x 787 columns]" 235 | ] 236 | }, 237 | "execution_count": 8, 238 | "metadata": {}, 239 | "output_type": "execute_result" 240 | } 241 | ], 242 | "source": [ 243 | "# Tag limit defines which dataset to load - those with tags having at least 50K, 20K, 10K, 5K or 2K instances\n", 244 | "TAG_LIMIT = 2000\n", 245 | "\n", 246 | "# Pre-computed sorted list of tag/index pairs\n", 247 | "sorted_all_tags = json.load(open(f'../data/stackoverflow/sorted_all_tags.{TAG_LIMIT}.json'))\n", 248 | "max_index = sorted_all_tags[-1][0] + 1\n", 249 | "\n", 250 | "# Load the parquet file using pyarrow for this tag limit, using the sorted tag index to specify the columns\n", 251 | "posts_df = pd.read_parquet(\n", 252 | " f'../data/stackoverflow/Questions.Stratified.Final.{TAG_LIMIT}.parquet',\n", 253 | " columns=['_Body'] + ['label_{}'.format(i) for i in range(0, max_index)],\n", 254 | " engine='pyarrow'\n", 255 | ")\n", 256 | "posts_df.head(2)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 9, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 268 | "1,554,788 Stack Overflow questions with a tag having at least 2,000 occurrences\n" 269 | ] 270 | } 271 | ], 272 | "source": [ 273 | "print(\n", 274 | " '{:,} Stack Overflow questions with a tag having at least 2,000 occurrences'.format(\n", 275 | " len(posts_df.index)\n", 276 | " )\n", 277 | ")" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 10, 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "Non-zero rows: 1,554,788, Total rows: 1,554,788, Non-zero ratio: 1.0, Least tags: 1, Most tags: 6\n" 290 | ] 291 | } 292 | ], 293 | "source": [ 294 | "test_matrix = posts_df[[f'label_{i}' for i in range(0, max_index)]].as_matrix()\n", 295 | "\n", 296 | "tests = np.count_nonzero(test_matrix.sum(axis=1)), \\\n", 297 | " test_matrix.sum(axis=1).shape[0], \\\n", 298 | " test_matrix.sum(axis=1).min(), \\\n", 299 | " test_matrix.sum(axis=1).max()\n", 300 | "\n", 301 | "print(f'Non-zero rows: {tests[0]:,}, Total rows: {tests[1]:,}, Non-zero ratio: {tests[0]/tests[1]:,}, Least tags: {tests[2]:,}, Most tags: {tests[3]:,}')" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 11, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "tag_index = json.load(open(f'../data/stackoverflow/tag_index.{TAG_LIMIT}.json'))\n", 311 | "index_tag = json.load(open(f'../data/stackoverflow/index_tag.{TAG_LIMIT}.json'))\n", 312 | "\n", 313 | "# Sanity check the different files\n", 314 | "assert( len(tag_index.keys()) == len(index_tag.keys()) == len(sorted_all_tags) )" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 12, 320 | "metadata": {}, 321 | "outputs": [ 322 | { 323 | "name": "stdout", 324 | "output_type": 
"stream", 325 | "text": [ 326 | "Highest Factor: 60 Training Count: 1,536,000\n" 327 | ] 328 | } 329 | ], 330 | "source": [ 331 | "# Convert label columns to numpy array\n", 332 | "labels = posts_df[list(posts_df.columns)[1:]].to_numpy()\n", 333 | "\n", 334 | "# Training_count must be a multiple of the BATCH_SIZE times the MAX_LEN for the Elmo embedding layer\n", 335 | "highest_factor = math.floor(len(posts_df.index) / (BATCH_SIZE * MAX_LEN))\n", 336 | "training_count = highest_factor * BATCH_SIZE * MAX_LEN\n", 337 | "print('Highest Factor: {:,} Training Count: {:,}'.format(highest_factor, training_count))\n", 338 | "\n", 339 | "documents = []\n", 340 | "for body in posts_df[0:training_count]['_Body'].values.tolist():\n", 341 | " words = body.tolist()\n", 342 | " documents.append(' '.join(words))\n", 343 | "\n", 344 | "labels = labels[0:training_count]\n", 345 | "\n", 346 | "# Conserve RAM\n", 347 | "del posts_df\n", 348 | "gc.collect()\n", 349 | "\n", 350 | "# Lengths for x and y match\n", 351 | "assert( len(documents) == training_count == labels.shape[0] )" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 13, 357 | "metadata": {}, 358 | "outputs": [ 359 | { 360 | "data": { 361 | "text/plain": [ 362 | "(1536000, 200)" 363 | ] 364 | }, 365 | "execution_count": 13, 366 | "metadata": {}, 367 | "output_type": "execute_result" 368 | } 369 | ], 370 | "source": [ 371 | "from tensorflow.keras.preprocessing.text import Tokenizer\n", 372 | "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", 373 | "\n", 374 | "tokenizer = Tokenizer(\n", 375 | " num_words=TOKEN_COUNT + 1,\n", 376 | " oov_token='__PAD__'\n", 377 | ")\n", 378 | "tokenizer.fit_on_texts(documents)\n", 379 | "tokenizer.word_index = {e:i for e,i in tokenizer.word_index.items() if i <= TOKEN_COUNT}\n", 380 | "\n", 381 | "sequences = tokenizer.texts_to_sequences(documents)\n", 382 | "\n", 383 | "padded_sequences = pad_sequences(\n", 384 | " sequences,\n", 385 | " maxlen=MAX_LEN,\n", 386 | " dtype='int32',\n", 387 | " padding='post',\n", 388 | " truncating='post',\n", 389 | " value=0,\n", 390 | ")\n", 391 | "tokenizer.sequences_to_matrix(padded_sequences, mode='tfidf')\n", 392 | "\n", 393 | "# Conserve RAM\n", 394 | "del documents\n", 395 | "del sequences\n", 396 | "gc.collect()\n", 397 | "\n", 398 | "# Verify that all padded documents are now the same length\n", 399 | "assert( min([len(x) for x in padded_sequences]) == MAX_LEN == max([len(x) for x in padded_sequences]) )\n", 400 | "\n", 401 | "padded_sequences.shape" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": {}, 407 | "source": [ 408 | "## Load GloVe Embeddings\n", 409 | "\n", 410 | "Stanford defines [GloVe Embeddings](https://nlp.stanford.edu/projects/glove/) as:\n", 411 | "\n", 412 | "> GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Training is performed on aggregated global word-word co-occurrence statistics from a corpus, and the resulting representations showcase interesting linear substructures of the word vector space.\n", 413 | "\n", 414 | "We'll try them out to see if they can beat our own embedding, specific to our data." 
415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 14, 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "def get_coefs(word,*arr): \n", 424 | " return word, np.asarray(arr, dtype='float32')\n", 425 | "\n", 426 | "embeddings_index = dict(get_coefs(*o.strip().split()) for o in open('../data/GloVe/glove.6B.300d.txt'))" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 15, 432 | "metadata": {}, 433 | "outputs": [ 434 | { 435 | "name": "stdout", 436 | "output_type": "stream", 437 | "text": [ 438 | "10000\n" 439 | ] 440 | }, 441 | { 442 | "data": { 443 | "text/plain": [ 444 | "((10000, 300),\n", 445 | " {'__PAD__': 1,\n", 446 | " 'pad': 2,\n", 447 | " 'i': 3,\n", 448 | " 'using': 4,\n", 449 | " 'like': 5,\n", 450 | " 'code': 6,\n", 451 | " 'the': 7,\n", 452 | " 'get': 8,\n", 453 | " 'use': 9,\n", 454 | " 'how': 10,\n", 455 | " 'file': 11,\n", 456 | " 'want': 12,\n", 457 | " 'would': 13,\n", 458 | " 'way': 14,\n", 459 | " 'error': 15,\n", 460 | " 'one': 16,\n", 461 | " 'data': 17,\n", 462 | " 'is': 18,\n", 463 | " '1': 19,\n", 464 | " 'need': 20,\n", 465 | " '2': 21,\n", 466 | " 'following': 22,\n", 467 | " 'problem': 23,\n", 468 | " 'trying': 24,\n", 469 | " 'this': 25,\n", 470 | " 'app': 26,\n", 471 | " 'work': 27,\n", 472 | " 'know': 28,\n", 473 | " 'user': 29,\n", 474 | " 'function': 30,\n", 475 | " 'c': 31,\n", 476 | " 'class': 32,\n", 477 | " 'what': 33,\n", 478 | " 'but': 34,\n", 479 | " 'tried': 35,\n", 480 | " 'application': 36,\n", 481 | " 'example': 37,\n", 482 | " 'so': 38,\n", 483 | " 'also': 39,\n", 484 | " 'method': 40,\n", 485 | " 'time': 41,\n", 486 | " 'set': 42,\n", 487 | " 'new': 43,\n", 488 | " 'server': 44,\n", 489 | " '0': 45,\n", 490 | " 'in': 46,\n", 491 | " 'something': 47,\n", 492 | " 'run': 48,\n", 493 | " 'thanks': 49,\n", 494 | " 'if': 50,\n", 495 | " 'value': 51,\n", 496 | " 'make': 52,\n", 497 | " 'my': 53,\n", 498 | " 'it': 54,\n", 499 | " 'help': 55,\n", 500 | " 'create': 56,\n", 501 | " 'works': 57,\n", 502 | " 'project': 58,\n", 503 | " '3': 59,\n", 504 | " 'first': 60,\n", 505 | " 'could': 61,\n", 506 | " 'however': 62,\n", 507 | " 'page': 63,\n", 508 | " 'working': 64,\n", 509 | " 'find': 65,\n", 510 | " 'see': 66,\n", 511 | " 'object': 67,\n", 512 | " 'list': 68,\n", 513 | " 'files': 69,\n", 514 | " 'question': 70,\n", 515 | " 'two': 71,\n", 516 | " 'when': 72,\n", 517 | " 'java': 73,\n", 518 | " 'type': 74,\n", 519 | " 'here': 75,\n", 520 | " 'add': 76,\n", 521 | " 'possible': 77,\n", 522 | " 'table': 78,\n", 523 | " 'different': 79,\n", 524 | " 'view': 80,\n", 525 | " 'test': 81,\n", 526 | " 'used': 82,\n", 527 | " 'string': 83,\n", 528 | " 'without': 84,\n", 529 | " 'line': 85,\n", 530 | " 'call': 86,\n", 531 | " '4': 87,\n", 532 | " 'web': 88,\n", 533 | " 'change': 89,\n", 534 | " 'any': 90,\n", 535 | " 'seems': 91,\n", 536 | " 'database': 92,\n", 537 | " 'able': 93,\n", 538 | " 'and': 94,\n", 539 | " 'text': 95,\n", 540 | " 'found': 96,\n", 541 | " 'image': 97,\n", 542 | " 'another': 98,\n", 543 | " 'for': 99,\n", 544 | " 'running': 100,\n", 545 | " 'android': 101,\n", 546 | " 'solution': 102,\n", 547 | " 'name': 103,\n", 548 | " 'fine': 104,\n", 549 | " 'output': 105,\n", 550 | " 'getting': 106,\n", 551 | " 'net': 107,\n", 552 | " 'update': 108,\n", 553 | " 'values': 109,\n", 554 | " 'array': 110,\n", 555 | " 'version': 111,\n", 556 | " '5': 112,\n", 557 | " 'edit': 113,\n", 558 | " 'windows': 114,\n", 559 | " 'try': 115,\n", 560 | " 'still': 116,\n", 561 | " 'issue': 117,\n", 562 | 
" 'access': 118,\n", 563 | " 'anyone': 119,\n", 564 | " 'number': 120,\n", 565 | " 'api': 121,\n", 566 | " 'simple': 122,\n", 567 | " 'python': 123,\n", 568 | " 'build': 124,\n", 569 | " 'can': 125,\n", 570 | " 'right': 126,\n", 571 | " 'even': 127,\n", 572 | " 'query': 128,\n", 573 | " 'php': 129,\n", 574 | " 'read': 130,\n", 575 | " 'case': 131,\n", 576 | " 'script': 132,\n", 577 | " 'please': 133,\n", 578 | " 'e': 134,\n", 579 | " 'command': 135,\n", 580 | " 'created': 136,\n", 581 | " 'wrong': 137,\n", 582 | " 'html': 138,\n", 583 | " 'form': 139,\n", 584 | " 'button': 140,\n", 585 | " 'service': 141,\n", 586 | " 'system': 142,\n", 587 | " 'instead': 143,\n", 588 | " 'sure': 144,\n", 589 | " 'request': 145,\n", 590 | " 'google': 146,\n", 591 | " 'result': 147,\n", 592 | " 'multiple': 148,\n", 593 | " 'model': 149,\n", 594 | " 'write': 150,\n", 595 | " 'inside': 151,\n", 596 | " 'now': 152,\n", 597 | " 'a': 153,\n", 598 | " 'program': 154,\n", 599 | " 'every': 155,\n", 600 | " 'since': 156,\n", 601 | " 'sql': 157,\n", 602 | " 'js': 158,\n", 603 | " 'cannot': 159,\n", 604 | " 'return': 160,\n", 605 | " 'called': 161,\n", 606 | " 'think': 162,\n", 607 | " 'variable': 163,\n", 608 | " 'many': 164,\n", 609 | " 'custom': 165,\n", 610 | " 'id': 166,\n", 611 | " 'really': 167,\n", 612 | " 'url': 168,\n", 613 | " 'library': 169,\n", 614 | " 'message': 170,\n", 615 | " 'xml': 171,\n", 616 | " 'client': 172,\n", 617 | " 'column': 173,\n", 618 | " 'looks': 174,\n", 619 | " 'why': 175,\n", 620 | " 'got': 176,\n", 621 | " 'understand': 177,\n", 622 | " 'based': 178,\n", 623 | " 'process': 179,\n", 624 | " 'controller': 180,\n", 625 | " 'well': 181,\n", 626 | " 'key': 182,\n", 627 | " 'show': 183,\n", 628 | " 'input': 184,\n", 629 | " 'best': 185,\n", 630 | " 'check': 186,\n", 631 | " 'memory': 187,\n", 632 | " 'back': 188,\n", 633 | " 'looking': 189,\n", 634 | " 'order': 190,\n", 635 | " 'go': 191,\n", 636 | " '7': 192,\n", 637 | " '8': 193,\n", 638 | " 'default': 194,\n", 639 | " 'correct': 195,\n", 640 | " 'point': 196,\n", 641 | " 'idea': 197,\n", 642 | " 'second': 198,\n", 643 | " 'etc': 199,\n", 644 | " 'javascript': 200,\n", 645 | " 'within': 201,\n", 646 | " 'start': 202,\n", 647 | " 'currently': 203,\n", 648 | " 'already': 204,\n", 649 | " 'exception': 205,\n", 650 | " 'seem': 206,\n", 651 | " 'via': 207,\n", 652 | " 'event': 208,\n", 653 | " 'as': 209,\n", 654 | " 'json': 210,\n", 655 | " 'field': 211,\n", 656 | " 'load': 212,\n", 657 | " 'added': 213,\n", 658 | " 'much': 214,\n", 659 | " 'x': 215,\n", 660 | " 'source': 216,\n", 661 | " 'property': 217,\n", 662 | " 'main': 218,\n", 663 | " 'specific': 219,\n", 664 | " 'bit': 220,\n", 665 | " 'post': 221,\n", 666 | " 'good': 222,\n", 667 | " 'done': 223,\n", 668 | " 'objects': 224,\n", 669 | " 'better': 225,\n", 670 | " '10': 226,\n", 671 | " 'part': 227,\n", 672 | " '6': 228,\n", 673 | " 'current': 229,\n", 674 | " 'say': 230,\n", 675 | " 'open': 231,\n", 676 | " 'size': 232,\n", 677 | " 'element': 233,\n", 678 | " 'does': 234,\n", 679 | " 'link': 235,\n", 680 | " 'folder': 236,\n", 681 | " 'content': 237,\n", 682 | " 'directory': 238,\n", 683 | " 'single': 239,\n", 684 | " 'results': 240,\n", 685 | " 'users': 241,\n", 686 | " 'click': 242,\n", 687 | " 'studio': 243,\n", 688 | " 'information': 244,\n", 689 | " 'returns': 245,\n", 690 | " 'instance': 246,\n", 691 | " 'always': 247,\n", 692 | " 'far': 248,\n", 693 | " 'put': 249,\n", 694 | " 'end': 250,\n", 695 | " 'date': 251,\n", 696 | " 'display': 252,\n", 697 | " 'jquery': 253,\n", 698 | 
" 'everything': 254,\n", 699 | " 'path': 255,\n", 700 | " 'methods': 256,\n", 701 | " 'similar': 257,\n", 702 | " 'thread': 258,\n", 703 | " 'loop': 259,\n", 704 | " 'framework': 260,\n", 705 | " 'around': 261,\n", 706 | " 'given': 262,\n", 707 | " 'look': 263,\n", 708 | " 'search': 264,\n", 709 | " 'visual': 265,\n", 710 | " 'anything': 266,\n", 711 | " 'browser': 267,\n", 712 | " 'select': 268,\n", 713 | " 'answer': 269,\n", 714 | " 'row': 270,\n", 715 | " 'functions': 271,\n", 716 | " 'missing': 272,\n", 717 | " 'store': 273,\n", 718 | " 'reference': 274,\n", 719 | " 'someone': 275,\n", 720 | " 'may': 276,\n", 721 | " 'advance': 277,\n", 722 | " 'appreciated': 278,\n", 723 | " 'pass': 279,\n", 724 | " 'creating': 280,\n", 725 | " 'g': 281,\n", 726 | " 'site': 282,\n", 727 | " 'node': 283,\n", 728 | " 'implement': 284,\n", 729 | " 'contains': 285,\n", 730 | " 'let': 286,\n", 731 | " 'control': 287,\n", 732 | " 'classes': 288,\n", 733 | " 'module': 289,\n", 734 | " 'ios': 290,\n", 735 | " 'thing': 291,\n", 736 | " 'there': 292,\n", 737 | " 'package': 293,\n", 738 | " 'send': 294,\n", 739 | " 'errors': 295,\n", 740 | " 'we': 296,\n", 741 | " 'actually': 297,\n", 742 | " 'log': 298,\n", 743 | " 'index': 299,\n", 744 | " 'uses': 300,\n", 745 | " 'last': 301,\n", 746 | " 'asp': 302,\n", 747 | " 'b': 303,\n", 748 | " 'changes': 304,\n", 749 | " 'going': 305,\n", 750 | " 'lot': 306,\n", 751 | " 'local': 307,\n", 752 | " 'to': 308,\n", 753 | " 'format': 309,\n", 754 | " 'install': 310,\n", 755 | " 'option': 311,\n", 756 | " 'remove': 312,\n", 757 | " 'or': 313,\n", 758 | " 'item': 314,\n", 759 | " 'map': 315,\n", 760 | " 'adding': 316,\n", 761 | " 'must': 317,\n", 762 | " 'either': 318,\n", 763 | " 'screen': 319,\n", 764 | " 'installed': 320,\n", 765 | " 'figure': 321,\n", 766 | " 'things': 322,\n", 767 | " 'correctly': 323,\n", 768 | " 'elements': 324,\n", 769 | " 'parameter': 325,\n", 770 | " 'might': 326,\n", 771 | " 'take': 327,\n", 772 | " 'window': 328,\n", 773 | " 'convert': 329,\n", 774 | " 'images': 330,\n", 775 | " 'template': 331,\n", 776 | " 'css': 332,\n", 777 | " 'background': 333,\n", 778 | " 'com': 334,\n", 779 | " 'ideas': 335,\n", 780 | " 'give': 336,\n", 781 | " 'note': 337,\n", 782 | " 'console': 338,\n", 783 | " 'connection': 339,\n", 784 | " 'setting': 340,\n", 785 | " 'made': 341,\n", 786 | " 'defined': 342,\n", 787 | " 'keep': 343,\n", 788 | " 'compile': 344,\n", 789 | " 'shows': 345,\n", 790 | " 'rows': 346,\n", 791 | " 'response': 347,\n", 792 | " 'nothing': 348,\n", 793 | " 'thank': 349,\n", 794 | " 'spring': 350,\n", 795 | " 'expected': 351,\n", 796 | " 'fix': 352,\n", 797 | " 'null': 353,\n", 798 | " 'save': 354,\n", 799 | " 'gets': 355,\n", 800 | " 'documentation': 356,\n", 801 | " 'plugin': 357,\n", 802 | " 'config': 358,\n", 803 | " 'columns': 359,\n", 804 | " 'properties': 360,\n", 805 | " 'structure': 361,\n", 806 | " 'reason': 362,\n", 807 | " 'side': 363,\n", 808 | " 'after': 364,\n", 809 | " 'items': 365,\n", 810 | " 'header': 366,\n", 811 | " 'long': 367,\n", 812 | " 'task': 368,\n", 813 | " 'next': 369,\n", 814 | " 'variables': 370,\n", 815 | " 'though': 371,\n", 816 | " 'difference': 372,\n", 817 | " 'approach': 373,\n", 818 | " 'writing': 374,\n", 819 | " 'gives': 375,\n", 820 | " 'tests': 376,\n", 821 | " 'non': 377,\n", 822 | " 'website': 378,\n", 823 | " 'login': 379,\n", 824 | " 'action': 380,\n", 825 | " 'support': 381,\n", 826 | " 'chrome': 382,\n", 827 | " 'then': 383,\n", 828 | " 'lines': 384,\n", 829 | " 'configuration': 385,\n", 830 | " 
'available': 386,\n", 831 | " 'several': 387,\n", 832 | " 'fields': 388,\n", 833 | " 'tell': 389,\n", 834 | " 'handle': 390,\n", 835 | " 'include': 391,\n", 836 | " 'parameters': 392,\n", 837 | " 'static': 393,\n", 838 | " 'maybe': 394,\n", 839 | " 'times': 395,\n", 840 | " 'interface': 396,\n", 841 | " 'stored': 397,\n", 842 | " 'mvc': 398,\n", 843 | " 'entity': 399,\n", 844 | " 'machine': 400,\n", 845 | " 'device': 401,\n", 846 | " 'calls': 402,\n", 847 | " 'mysql': 403,\n", 848 | " 'changed': 404,\n", 849 | " 'needs': 405,\n", 850 | " 'component': 406,\n", 851 | " 'testing': 407,\n", 852 | " 'statement': 408,\n", 853 | " 'achieve': 409,\n", 854 | " 'solve': 410,\n", 855 | " 'top': 411,\n", 856 | " 'empty': 412,\n", 857 | " 'else': 413,\n", 858 | " 'implementation': 414,\n", 859 | " 'generate': 415,\n", 860 | " 'types': 416,\n", 861 | " 'state': 417,\n", 862 | " 'thought': 418,\n", 863 | " 'http': 419,\n", 864 | " 'copy': 420,\n", 865 | " 'calling': 421,\n", 866 | " 'questions': 422,\n", 867 | " 'ui': 423,\n", 868 | " 'db': 424,\n", 869 | " 'setup': 425,\n", 870 | " 'standard': 426,\n", 871 | " 'session': 427,\n", 872 | " 'vs': 428,\n", 873 | " 'wondering': 429,\n", 874 | " 'rails': 430,\n", 875 | " 'problems': 431,\n", 876 | " 'generated': 432,\n", 877 | " 'simply': 433,\n", 878 | " 'compiler': 434,\n", 879 | " 'rather': 435,\n", 880 | " 'tables': 436,\n", 881 | " 'address': 437,\n", 882 | " 'all': 438,\n", 883 | " 'linux': 439,\n", 884 | " 'allow': 440,\n", 885 | " 'attribute': 441,\n", 886 | " 'core': 442,\n", 887 | " 'basically': 443,\n", 888 | " 'reading': 444,\n", 889 | " 'django': 445,\n", 890 | " 'written': 446,\n", 891 | " 'print': 447,\n", 892 | " 'stack': 448,\n", 893 | " 'large': 449,\n", 894 | " 'box': 450,\n", 895 | " 'never': 451,\n", 896 | " 'execute': 452,\n", 897 | " 'kind': 453,\n", 898 | " 'eclipse': 454,\n", 899 | " 'debug': 455,\n", 900 | " 'fails': 456,\n", 901 | " 'insert': 457,\n", 902 | " 'document': 458,\n", 903 | " 'on': 459,\n", 904 | " 'properly': 460,\n", 905 | " 'environment': 461,\n", 906 | " 'mean': 462,\n", 907 | " 'level': 463,\n", 908 | " 'mode': 464,\n", 909 | " 'color': 465,\n", 910 | " 'delete': 466,\n", 911 | " 'tag': 467,\n", 912 | " 'takes': 468,\n", 913 | " 'password': 469,\n", 914 | " 'happens': 470,\n", 915 | " 'making': 471,\n", 916 | " 'sort': 472,\n", 917 | " 'bar': 473,\n", 918 | " 'collection': 474,\n", 919 | " 'full': 475,\n", 920 | " 'certain': 476,\n", 921 | " 'says': 477,\n", 922 | " 'location': 478,\n", 923 | " 'performance': 479,\n", 924 | " 'language': 480,\n", 925 | " 'related': 481,\n", 926 | " 'rest': 482,\n", 927 | " 'layout': 483,\n", 928 | " 'follows': 484,\n", 929 | " 'domain': 485,\n", 930 | " 'place': 486,\n", 931 | " 'r': 487,\n", 932 | " 'sample': 488,\n", 933 | " 'do': 489,\n", 934 | " 'automatically': 490,\n", 935 | " 'quite': 491,\n", 936 | " 'options': 492,\n", 937 | " 'parent': 493,\n", 938 | " 'context': 494,\n", 939 | " 'connect': 495,\n", 940 | " 'directly': 496,\n", 941 | " 'video': 497,\n", 942 | " 'whether': 498,\n", 943 | " 'runs': 499,\n", 944 | " 'filter': 500,\n", 945 | " 'true': 501,\n", 946 | " 'left': 502,\n", 947 | " 'menu': 503,\n", 948 | " 'numbers': 504,\n", 949 | " 'whole': 505,\n", 950 | " 'makes': 506,\n", 951 | " 'that': 507,\n", 952 | " 'n': 508,\n", 953 | " 'activity': 509,\n", 954 | " 'per': 510,\n", 955 | " 'group': 511,\n", 956 | " 'exactly': 512,\n", 957 | " 'started': 513,\n", 958 | " 'manually': 514,\n", 959 | " 'dll': 515,\n", 960 | " 'settings': 516,\n", 961 | " 'pattern': 517,\n", 
962 | " 'play': 518,\n", 963 | " 'ok': 519,\n", 964 | " 'upload': 520,\n", 965 | " 'behavior': 521,\n", 966 | " 'child': 522,\n", 967 | " 'separate': 523,\n", 968 | " 'failed': 524,\n", 969 | " 'characters': 525,\n", 970 | " 'suggestions': 526,\n", 971 | " 'little': 527,\n", 972 | " 'not': 528,\n", 973 | " 'appears': 529,\n", 974 | " 'existing': 530,\n", 975 | " 'avoid': 531,\n", 976 | " 'basic': 532,\n", 977 | " 'strings': 533,\n", 978 | " 'download': 534,\n", 979 | " 'container': 535,\n", 980 | " 'tab': 536,\n", 981 | " 'pages': 537,\n", 982 | " '9': 538,\n", 983 | " 'seen': 539,\n", 984 | " 'pretty': 540,\n", 985 | " 'small': 541,\n", 986 | " 'provide': 542,\n", 987 | " 'email': 543,\n", 988 | " 'ajax': 544,\n", 989 | " 'issues': 545,\n", 990 | " 'original': 546,\n", 991 | " 'requests': 547,\n", 992 | " 'means': 548,\n", 993 | " 'selected': 549,\n", 994 | " 'import': 550,\n", 995 | " 'syntax': 551,\n", 996 | " 'block': 552,\n", 997 | " 'names': 553,\n", 998 | " 'yet': 554,\n", 999 | " 'repository': 555,\n", 1000 | " 'updated': 556,\n", 1001 | " 'changing': 557,\n", 1002 | " 'unit': 558,\n", 1003 | " 'come': 559,\n", 1004 | " 'match': 560,\n", 1005 | " 'easy': 561,\n", 1006 | " 'xcode': 562,\n", 1007 | " 'projects': 563,\n", 1008 | " 'replace': 564,\n", 1009 | " 'particular': 565,\n", 1010 | " 'examples': 566,\n", 1011 | " 'constructor': 567,\n", 1012 | " 'space': 568,\n", 1013 | " 'info': 569,\n", 1014 | " 'cell': 570,\n", 1015 | " 'built': 571,\n", 1016 | " 'actual': 572,\n", 1017 | " 'unable': 573,\n", 1018 | " 'shown': 574,\n", 1019 | " 'required': 575,\n", 1020 | " 'jar': 576,\n", 1021 | " 'root': 577,\n", 1022 | " 'explain': 578,\n", 1023 | " 'events': 579,\n", 1024 | " 'move': 580,\n", 1025 | " 'worked': 581,\n", 1026 | " 'old': 582,\n", 1027 | " 'no': 583,\n", 1028 | " 'three': 584,\n", 1029 | " 'git': 585,\n", 1030 | " 'define': 586,\n", 1031 | " 'loaded': 587,\n", 1032 | " 'step': 588,\n", 1033 | " 'count': 589,\n", 1034 | " 'messages': 590,\n", 1035 | " 'great': 591,\n", 1036 | " 'real': 592,\n", 1037 | " 'building': 593,\n", 1038 | " 'loading': 594,\n", 1039 | " 'authentication': 595,\n", 1040 | " 'base': 596,\n", 1041 | " 'network': 597,\n", 1042 | " 'account': 598,\n", 1043 | " 'character': 599,\n", 1044 | " 'views': 600,\n", 1045 | " 'iphone': 601,\n", 1046 | " 'clear': 602,\n", 1047 | " 'development': 603,\n", 1048 | " 'points': 604,\n", 1049 | " 'token': 605,\n", 1050 | " 'position': 606,\n", 1051 | " 'argument': 607,\n", 1052 | " 'dynamic': 608,\n", 1053 | " 'checked': 609,\n", 1054 | " 'common': 610,\n", 1055 | " 'ie': 611,\n", 1056 | " 'libraries': 612,\n", 1057 | " 'extension': 613,\n", 1058 | " 'os': 614,\n", 1059 | " 'returned': 615,\n", 1060 | " 'design': 616,\n", 1061 | " 'angular': 617,\n", 1062 | " 'firefox': 618,\n", 1063 | " 'apache': 619,\n", 1064 | " 'from': 620,\n", 1065 | " 'which': 621,\n", 1066 | " 'style': 622,\n", 1067 | " 'cache': 623,\n", 1068 | " 'security': 624,\n", 1069 | " 'native': 625,\n", 1070 | " 'stream': 626,\n", 1071 | " 'record': 627,\n", 1072 | " 'successfully': 628,\n", 1073 | " 'width': 629,\n", 1074 | " 'word': 630,\n", 1075 | " 'showing': 631,\n", 1076 | " 'with': 632,\n", 1077 | " 'facebook': 633,\n", 1078 | " 'resource': 634,\n", 1079 | " 'parse': 635,\n", 1080 | " 'exist': 636,\n", 1081 | " 'where': 637,\n", 1082 | " 'records': 638,\n", 1083 | " 'services': 639,\n", 1084 | " 'contain': 640,\n", 1085 | " 'course': 641,\n", 1086 | " 'redirect': 642,\n", 1087 | " 'needed': 643,\n", 1088 | " 'according': 644,\n", 1089 | " 'except': 
645,\n", 1090 | " 'ruby': 646,\n", 1091 | " 'mobile': 647,\n", 1092 | " 'are': 648,\n", 1093 | " 'target': 649,\n", 1094 | " 'threads': 650,\n", 1095 | " 'enough': 651,\n", 1096 | " 'least': 652,\n", 1097 | " 'feature': 653,\n", 1098 | " 'expression': 654,\n", 1099 | " '11': 655,\n", 1100 | " 'comes': 656,\n", 1101 | " 'previous': 657,\n", 1102 | " 'nested': 658,\n", 1103 | " 'validation': 659,\n", 1104 | " 'guess': 660,\n", 1105 | " 'appear': 661,\n", 1106 | " '100': 662,\n", 1107 | " 'sdk': 663,\n", 1108 | " 'job': 664,\n", 1109 | " 'dialog': 665,\n", 1110 | " 'scroll': 666,\n", 1111 | " 'pointer': 667,\n", 1112 | " 'across': 668,\n", 1113 | " 'height': 669,\n", 1114 | " 'int': 670,\n", 1115 | " 'people': 671,\n", 1116 | " 'apps': 672,\n", 1117 | " 'probably': 673,\n", 1118 | " 'close': 674,\n", 1119 | " 'at': 675,\n", 1120 | " 'product': 676,\n", 1121 | " 'vector': 677,\n", 1122 | " 'valid': 678,\n", 1123 | " 'functionality': 679,\n", 1124 | " 'tree': 680,\n", 1125 | " 'maven': 681,\n", 1126 | " 'sometimes': 682,\n", 1127 | " 'seconds': 683,\n", 1128 | " 'div': 684,\n", 1129 | " 'wanted': 685,\n", 1130 | " 'binary': 686,\n", 1131 | " 'displayed': 687,\n", 1132 | " 'external': 688,\n", 1133 | " 'runtime': 689,\n", 1134 | " 'stop': 690,\n", 1135 | " 'bug': 691,\n", 1136 | " 'dependencies': 692,\n", 1137 | " 'implemented': 693,\n", 1138 | " 'you': 694,\n", 1139 | " 'warning': 695,\n", 1140 | " 'arguments': 696,\n", 1141 | " 'complete': 697,\n", 1142 | " 'programming': 698,\n", 1143 | " 'dynamically': 699,\n", 1144 | " 'goes': 700,\n", 1145 | " 'keys': 701,\n", 1146 | " 'random': 702,\n", 1147 | " 'section': 703,\n", 1148 | " 'perform': 704,\n", 1149 | " 'frame': 705,\n", 1150 | " 'microsoft': 706,\n", 1151 | " 'passed': 707,\n", 1152 | " 'integer': 708,\n", 1153 | " 'normal': 709,\n", 1154 | " 'specify': 710,\n", 1155 | " 'tool': 711,\n", 1156 | " 'route': 712,\n", 1157 | " 'remote': 713,\n", 1158 | " 'somehow': 714,\n", 1159 | " 'cause': 715,\n", 1160 | " 'due': 716,\n", 1161 | " 'apply': 717,\n", 1162 | " 'solutions': 718,\n", 1163 | " 'containing': 719,\n", 1164 | " 'links': 720,\n", 1165 | " 'less': 721,\n", 1166 | " 'none': 722,\n", 1167 | " 'bad': 723,\n", 1168 | " 'status': 724,\n", 1169 | " 'cases': 725,\n", 1170 | " 'matrix': 726,\n", 1171 | " 'passing': 727,\n", 1172 | " 'modules': 728,\n", 1173 | " 'algorithm': 729,\n", 1174 | " 'docker': 730,\n", 1175 | " 'receive': 731,\n", 1176 | " 'am': 732,\n", 1177 | " 'phone': 733,\n", 1178 | " 'unfortunately': 734,\n", 1179 | " 'trouble': 735,\n", 1180 | " 'versions': 736,\n", 1181 | " 'binding': 737,\n", 1182 | " 'public': 738,\n", 1183 | " 'details': 739,\n", 1184 | " 'resources': 740,\n", 1185 | " 'later': 741,\n", 1186 | " 'thinking': 742,\n", 1187 | " 'should': 743,\n", 1188 | " 'ways': 744,\n", 1189 | " 'react': 745,\n", 1190 | " 'words': 746,\n", 1191 | " 'exists': 747,\n", 1192 | " 'starting': 748,\n", 1193 | " 'tools': 749,\n", 1194 | " 'socket': 750,\n", 1195 | " 'pdf': 751,\n", 1196 | " 'port': 752,\n", 1197 | " 'big': 753,\n", 1198 | " 'models': 754,\n", 1199 | " 'entire': 755,\n", 1200 | " 'applications': 756,\n", 1201 | " 'logic': 757,\n", 1202 | " 'operation': 758,\n", 1203 | " 'looked': 759,\n", 1204 | " 'answers': 760,\n", 1205 | " 'title': 761,\n", 1206 | " 'admin': 762,\n", 1207 | " 'mac': 763,\n", 1208 | " 'dependency': 764,\n", 1209 | " 'queries': 765,\n", 1210 | " 'assume': 766,\n", 1211 | " 'graph': 767,\n", 1212 | " 'swift': 768,\n", 1213 | " 'generic': 769,\n", 1214 | " 'switch': 770,\n", 1215 | " 'length': 
771,\n", 1216 | " 'happening': 772,\n", 1217 | " 'understanding': 773,\n", 1218 | " 'day': 774,\n", 1219 | " 'retrieve': 775,\n", 1220 | " 'csv': 776,\n", 1221 | " 'game': 777,\n", 1222 | " 'named': 778,\n", 1223 | " 'buttons': 779,\n", 1224 | " 'laravel': 780,\n", 1225 | " 'excel': 781,\n", 1226 | " 'obviously': 782,\n", 1227 | " 'grid': 783,\n", 1228 | " 'tags': 784,\n", 1229 | " 'unique': 785,\n", 1230 | " 'host': 786,\n", 1231 | " 'range': 787,\n", 1232 | " 'release': 788,\n", 1233 | " '64': 789,\n", 1234 | " 'require': 790,\n", 1235 | " 'specified': 791,\n", 1236 | " 'developing': 792,\n", 1237 | " 'global': 793,\n", 1238 | " 'p': 794,\n", 1239 | " 'assembly': 795,\n", 1240 | " 'shared': 796,\n", 1241 | " 'happen': 797,\n", 1242 | " 'modify': 798,\n", 1243 | " 'stuff': 799,\n", 1244 | " 'including': 800,\n", 1245 | " 'hard': 801,\n", 1246 | " 'drop': 802,\n", 1247 | " 'additional': 803,\n", 1248 | " 'auto': 804,\n", 1249 | " 'operator': 805,\n", 1250 | " '12': 806,\n", 1251 | " 'returning': 807,\n", 1252 | " 'false': 808,\n", 1253 | " 'tutorial': 809,\n", 1254 | " 'duplicate': 810,\n", 1255 | " 'sent': 811,\n", 1256 | " 'forms': 812,\n", 1257 | " 'contents': 813,\n", 1258 | " 'queue': 814,\n", 1259 | " 'expect': 815,\n", 1260 | " 'org': 816,\n", 1261 | " 'enter': 817,\n", 1262 | " 'situation': 818,\n", 1263 | " 'completely': 819,\n", 1264 | " 'exe': 820,\n", 1265 | " 'home': 821,\n", 1266 | " 'anybody': 822,\n", 1267 | " 'packages': 823,\n", 1268 | " 'push': 824,\n", 1269 | " 'attributes': 825,\n", 1270 | " 'layer': 826,\n", 1271 | " 'trigger': 827,\n", 1272 | " 'various': 828,\n", 1273 | " 'executed': 829,\n", 1274 | " 'shell': 830,\n", 1275 | " 'prevent': 831,\n", 1276 | " 'gcc': 832,\n", 1277 | " 'bottom': 833,\n", 1278 | " 'scala': 834,\n", 1279 | " 'gradle': 835,\n", 1280 | " 'double': 836,\n", 1281 | " 'procedure': 837,\n", 1282 | " 'handler': 838,\n", 1283 | " 'requires': 839,\n", 1284 | " 'schema': 840,\n", 1285 | " 'fact': 841,\n", 1286 | " 'certificate': 842,\n", 1287 | " 'effect': 843,\n", 1288 | " 'nodes': 844,\n", 1289 | " 'included': 845,\n", 1290 | " 'lib': 846,\n", 1291 | " 'fixed': 847,\n", 1292 | " 'report': 848,\n", 1293 | " '20': 849,\n", 1294 | " 'resolve': 850,\n", 1295 | " 'sub': 851,\n", 1296 | " 'lists': 852,\n", 1297 | " 'proxy': 853,\n", 1298 | " 'processing': 854,\n", 1299 | " 'internet': 855,\n", 1300 | " 'wpf': 856,\n", 1301 | " 'execution': 857,\n", 1302 | " 'outside': 858,\n", 1303 | " 'together': 859,\n", 1304 | " 'supposed': 860,\n", 1305 | " 'scope': 861,\n", 1306 | " 'invalid': 862,\n", 1307 | " 'commands': 863,\n", 1308 | " 'im': 864,\n", 1309 | " 'recently': 865,\n", 1310 | " 'ask': 866,\n", 1311 | " 'private': 867,\n", 1312 | " 'icon': 868,\n", 1313 | " 'render': 869,\n", 1314 | " 'stuck': 870,\n", 1315 | " 'references': 871,\n", 1316 | " 'sense': 872,\n", 1317 | " 'bootstrap': 873,\n", 1318 | " 'oracle': 874,\n", 1319 | " 'provided': 875,\n", 1320 | " 'latest': 876,\n", 1321 | " 'py': 877,\n", 1322 | " 'sending': 878,\n", 1323 | " 'some': 879,\n", 1324 | " 'perfectly': 880,\n", 1325 | " 'configure': 881,\n", 1326 | " 'spark': 882,\n", 1327 | " 'headers': 883,\n", 1328 | " 'proper': 884,\n", 1329 | " 'ubuntu': 885,\n", 1330 | " 'callback': 886,\n", 1331 | " 'general': 887,\n", 1332 | " 'suppose': 888,\n", 1333 | " 'hash': 889,\n", 1334 | " 'ip': 890,\n", 1335 | " 'member': 891,\n", 1336 | " 'along': 892,\n", 1337 | " 'believe': 893,\n", 1338 | " 'linq': 894,\n", 1339 | " 'developer': 895,\n", 1340 | " 'detect': 896,\n", 1341 | " 'font': 897,\n", 
1342 | " 'disable': 898,\n", 1343 | " 'logs': 899,\n", 1344 | " 'searching': 900,\n", 1345 | " 'definition': 901,\n", 1346 | " 'strange': 902,\n", 1347 | " 'greatly': 903,\n", 1348 | " 'present': 904,\n", 1349 | " 'creates': 905,\n", 1350 | " 'scenario': 906,\n", 1351 | " 'deploy': 907,\n", 1352 | " 'h': 908,\n", 1353 | " 'starts': 909,\n", 1354 | " 'debugging': 910,\n", 1355 | " 'force': 911,\n", 1356 | " 'plot': 912,\n", 1357 | " 'iis': 913,\n", 1358 | " 'slow': 914,\n", 1359 | " 'days': 915,\n", 1360 | " 'although': 916,\n", 1361 | " 'noticed': 917,\n", 1362 | " 'relevant': 918,\n", 1363 | " 'practice': 919,\n", 1364 | " 'bytes': 920,\n", 1365 | " 'amount': 921,\n", 1366 | " 'navigation': 922,\n", 1367 | " 'entry': 923,\n", 1368 | " 'hope': 924,\n", 1369 | " 'matter': 925,\n", 1370 | " 'extract': 926,\n", 1371 | " 'managed': 927,\n", 1372 | " 'whenever': 928,\n", 1373 | " 'allows': 929,\n", 1374 | " 'compiled': 930,\n", 1375 | " 'usage': 931,\n", 1376 | " 'throws': 932,\n", 1377 | " 'arrays': 933,\n", 1378 | " 'learning': 934,\n", 1379 | " 'appreciate': 935,\n", 1380 | " 'attempt': 936,\n", 1381 | " 'join': 937,\n", 1382 | " 'f': 938,\n", 1383 | " 'lambda': 939,\n", 1384 | " 'share': 940,\n", 1385 | " 'suggest': 941,\n", 1386 | " 'instances': 942,\n", 1387 | " 'wrote': 943,\n", 1388 | " 'split': 944,\n", 1389 | " 'platform': 945,\n", 1390 | " 'active': 946,\n", 1391 | " 'keyboard': 947,\n", 1392 | " 'label': 948,\n", 1393 | " 'others': 949,\n", 1394 | " 'kernel': 950,\n", 1395 | " 'extra': 951,\n", 1396 | " 'operations': 952,\n", 1397 | " 'nice': 953,\n", 1398 | " 'third': 954,\n", 1399 | " 'animation': 955,\n", 1400 | " 'fragment': 956,\n", 1401 | " 'determine': 957,\n", 1402 | " 'components': 958,\n", 1403 | " 'yes': 959,\n", 1404 | " 'equivalent': 960,\n", 1405 | " 'came': 961,\n", 1406 | " 'docs': 962,\n", 1407 | " 'bind': 963,\n", 1408 | " 'confused': 964,\n", 1409 | " 'below': 965,\n", 1410 | " 'each': 966,\n", 1411 | " 'free': 967,\n", 1412 | " 'implementing': 968,\n", 1413 | " 'fail': 969,\n", 1414 | " 'efficient': 970,\n", 1415 | " 'body': 971,\n", 1416 | " 'computer': 972,\n", 1417 | " 'removed': 973,\n", 1418 | " 'undefined': 974,\n", 1419 | " 'behind': 975,\n", 1420 | " 'internal': 976,\n", 1421 | " 'github': 977,\n", 1422 | " 'engine': 978,\n", 1423 | " 'easily': 979,\n", 1424 | " 'merge': 980,\n", 1425 | " 'by': 981,\n", 1426 | " 'dictionary': 982,\n", 1427 | " 'consider': 983,\n", 1428 | " 'tasks': 984,\n", 1429 | " 'specifically': 985,\n", 1430 | " 'dataframe': 986,\n", 1431 | " 'listview': 987,\n", 1432 | " 'ssl': 988,\n", 1433 | " 'steps': 989,\n", 1434 | " 'while': 990,\n", 1435 | " 'checking': 991,\n", 1436 | " 'enable': 992,\n", 1437 | " 'behaviour': 993,\n", 1438 | " 'longer': 994,\n", 1439 | " 'production': 995,\n", 1440 | " 'success': 996,\n", 1441 | " 'anyway': 997,\n", 1442 | " 'camera': 998,\n", 1443 | " 'comment': 999,\n", 1444 | " 'entities': 1000,\n", 1445 | " ...})" 1446 | ] 1447 | }, 1448 | "execution_count": 15, 1449 | "metadata": {}, 1450 | "output_type": "execute_result" 1451 | } 1452 | ], 1453 | "source": [ 1454 | "# Create embeddings matrix\n", 1455 | "all_embs = np.stack(embeddings_index.values())\n", 1456 | "emb_mean, emb_std = all_embs.mean(), all_embs.std()\n", 1457 | "\n", 1458 | "# Create embedding matrix using our vocabulary\n", 1459 | "word_index = tokenizer.word_index\n", 1460 | "nb_words = min(TOKEN_COUNT, len(word_index))\n", 1461 | "print(nb_words)\n", 1462 | "\n", 1463 | "# Random normal for missing entries\n", 1464 | "# embedding_matrix 
= np.random.normal(emb_mean, emb_std, (nb_words, EMBED_SIZE))\n", 1465 | "# for word, i in word_index.items():\n", 1466 | "# embedding_vector = embeddings_index.get(word)\n", 1467 | "# if embedding_vector is not None: \n", 1468 | "# embedding_matrix[i] = embedding_vector\n", 1469 | "\n", 1470 | "# Zero for missing entries\n", 1471 | "embedding_matrix = np.zeros((nb_words, EMBED_SIZE))\n", 1472 | "for word, i in word_index.items():\n", 1473 | " embedding_vector = embeddings_index.get(word)\n", 1474 | " if embedding_vector is not None:\n", 1475 | " embedding_matrix[i] = embedding_vector\n", 1476 | "\n", 1477 | "embedding_matrix.shape, word_index" 1478 | ] 1479 | }, 1480 | { 1481 | "cell_type": "markdown", 1482 | "metadata": {}, 1483 | "source": [ 1484 | "## Load Elmo Embedding layer\n", 1485 | "\n", 1486 | "Here we load the Elmo embedding layer using Tensorflow Hub." 1487 | ] 1488 | }, 1489 | { 1490 | "cell_type": "code", 1491 | "execution_count": 16, 1492 | "metadata": {}, 1493 | "outputs": [ 1494 | { 1495 | "name": "stdout", 1496 | "output_type": "stream", 1497 | "text": [ 1498 | "WARNING:tensorflow:From /home/rjurney/anaconda3/envs/weak/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:3632: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 1499 | "Instructions for updating:\n", 1500 | "Colocations handled automatically by placer.\n" 1501 | ] 1502 | }, 1503 | { 1504 | "name": "stderr", 1505 | "output_type": "stream", 1506 | "text": [ 1507 | "WARNING:tensorflow:From /home/rjurney/anaconda3/envs/weak/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:3632: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 1508 | "Instructions for updating:\n", 1509 | "Colocations handled automatically by placer.\n" 1510 | ] 1511 | }, 1512 | { 1513 | "name": "stdout", 1514 | "output_type": "stream", 1515 | "text": [ 1516 | "INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n" 1517 | ] 1518 | }, 1519 | { 1520 | "name": "stderr", 1521 | "output_type": "stream", 1522 | "text": [ 1523 | "INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n" 1524 | ] 1525 | } 1526 | ], 1527 | "source": [ 1528 | "elmo = hub.Module(\"https://tfhub.dev/google/elmo/2\", trainable=True)\n", 1529 | "embeddings = elmo(\n", 1530 | " [\"the cat is on the mat\", \"dogs are in the fog\"],\n", 1531 | " signature=\"default\",\n", 1532 | " as_dict=True)[\"elmo\"]" 1533 | ] 1534 | }, 1535 | { 1536 | "cell_type": "code", 1537 | "execution_count": null, 1538 | "metadata": {}, 1539 | "outputs": [], 1540 | "source": [ 1541 | "from sklearn.model_selection import train_test_split\n", 1542 | "\n", 1543 | "X_train, X_test, y_train, y_test = train_test_split(\n", 1544 | " padded_sequences,\n", 1545 | " labels,\n", 1546 | " test_size=TEST_SPLIT,\n", 1547 | " random_state=1337\n", 1548 | ")\n", 1549 | "\n", 1550 | "# Conserve RAM\n", 1551 | "del padded_sequences\n", 1552 | "del labels\n", 1553 | "gc.collect()\n", 1554 | "\n", 1555 | "assert(X_train.shape[0] == y_train.shape[0])\n", 1556 | "assert(X_train.shape[1] == MAX_LEN)\n", 1557 | "assert(X_test.shape[0] == y_test.shape[0]) \n", 1558 | "assert(X_test.shape[1] == MAX_LEN)" 1559 | ] 1560 | }, 1561 | { 1562 | "cell_type": "code", 1563 | "execution_count": null, 1564 | "metadata": {}, 1565 | "outputs": [], 1566 | "source": [ 1567 | "train_weight_vec = 
list(np.max(np.sum(y_train, axis=0)) / np.sum(y_train, axis=0))\n", 1568 | "train_class_weights = {i: train_weight_vec[i] for i in range(y_train.shape[1])}\n", 1569 | "\n", 1570 | "test_weight_vec = list(np.max(np.sum(y_test, axis=0)) / np.sum(y_test, axis=0))\n", 1571 | "test_class_weights = {i: test_weight_vec[i] for i in range(y_test.shape[1])}\n", 1572 | "\n", 1573 | "sorted(list(train_class_weights.items()), key=lambda x: x[1]), sorted(list(test_class_weights.items()), key=lambda x: x[1])" 1574 | ] 1575 | }, 1576 | { 1577 | "cell_type": "markdown", 1578 | "metadata": {}, 1579 | "source": [ 1580 | "## Create a Performance Log for the Model\n", 1581 | "\n", 1582 | "We will log the original performance as a reference point as well as the performance of the latest model to the current run." 1583 | ] 1584 | }, 1585 | { 1586 | "cell_type": "code", 1587 | "execution_count": 17, 1588 | "metadata": {}, 1589 | "outputs": [ 1590 | { 1591 | "data": { 1592 | "text/plain": [ 1593 | "0" 1594 | ] 1595 | }, 1596 | "execution_count": 17, 1597 | "metadata": {}, 1598 | "output_type": "execute_result" 1599 | } 1600 | ], 1601 | "source": [ 1602 | "try:\n", 1603 | " simple_log\n", 1604 | "except NameError:\n", 1605 | " simple_log = []\n", 1606 | "\n", 1607 | "try:\n", 1608 | " with open('simple_log.jsonl') as f:\n", 1609 | " for line in f:\n", 1610 | " simple_log.append(json.loads(line))\n", 1611 | "except FileNotFoundError:\n", 1612 | " pass\n", 1613 | "\n", 1614 | "SEQUENCE = simple_log[-1]['sequence'] if len(simple_log) > 0 else 0\n", 1615 | "\n", 1616 | "SEQUENCE" 1617 | ] 1618 | }, 1619 | { 1620 | "cell_type": "markdown", 1621 | "metadata": {}, 1622 | "source": [ 1623 | "## Try a Simple CNN Model to Classify Questions to their Corresponding Tags\n", 1624 | "\n", 1625 | "Now we’re ready to train a model to classify/label questions with tag categories. We start with a simple model with one `Conv1D`/`GlobalMaxPool1D`. We use the functional API and we’ve heavily parametrized the code so as to facilitate experimentation." 
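Before fitting, it is worth seeing what the inverse-frequency class weights computed above actually do: each tag's weight is the count of the most common tag divided by that tag's own count, so rare tags contribute proportionally more to the loss. A toy illustration with made-up counts (the real weights come from `y_train`):

```python
import numpy as np

# Hypothetical one-hot label matrix: 3 documents, 3 tags
y_toy = np.array([
    [1, 0, 0],
    [1, 1, 0],
    [1, 0, 1],
])

tag_counts = np.sum(y_toy, axis=0)         # [3, 1, 1]
weights = np.max(tag_counts) / tag_counts  # [1.0, 3.0, 3.0]
class_weights = {i: w for i, w in enumerate(weights)}
# {0: 1.0, 1: 3.0, 2: 3.0} -- the rarer tags are weighted three times as heavily
```

These are the dictionaries passed to `model.fit()` as `class_weight=train_class_weights` below.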
1626 | ] 1627 | }, 1628 | { 1629 | "cell_type": "code", 1630 | "execution_count": null, 1631 | "metadata": {}, 1632 | "outputs": [], 1633 | "source": [ 1634 | "from tensorflow.keras.initializers import RandomUniform\n", 1635 | "from tensorflow.keras.models import Sequential\n", 1636 | "from tensorflow.keras.layers import Dense, Embedding, Flatten, GlobalMaxPool1D, Dropout, Conv1D\n", 1637 | "from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint\n", 1638 | "from tensorflow.keras.losses import binary_crossentropy, kld\n", 1639 | "from tensorflow.keras.optimizers import Adam\n", 1640 | "\n", 1641 | "import lib.utils\n", 1642 | "\n", 1643 | "\n", 1644 | "FILTER_COUNT = 128\n", 1645 | "FILTER_SIZE = 3\n", 1646 | "EPOCHS = 8\n", 1647 | "ACTIVATION = 'selu'\n", 1648 | "CONV_PADDING = 'same'\n", 1649 | "STRIDES = 1\n", 1650 | "EMBED_SIZE = 300\n", 1651 | "EMBED_DROPOUT_RATIO = 0.1\n", 1652 | "CONV_DROPOUT_RATIO = 0.1\n", 1653 | "\n", 1654 | "EXPERIMENT_NAME = 'simple_cnn_again'\n", 1655 | "\n", 1656 | "if len(simple_log) > 0 and EXPERIMENT_NAME == simple_log[-1]['name']:\n", 1657 | " print('RENAME YOUR EXPERIMENT')\n", 1658 | " raise Exception('RENAME YOUR EXPERIMENT')\n", 1659 | "\n", 1660 | "SEQUENCE += 1\n", 1661 | "\n", 1662 | "\n", 1663 | "# Weights and Biases Monitoring\n", 1664 | "# import wandb\n", 1665 | "# from wandb.keras import WandbCallback\n", 1666 | "# wandb.init(project=\"weakly-supervised-learning\", name=EXPERIMENT_NAME)\n", 1667 | "# config = wandb.config\n", 1668 | "\n", 1669 | "# config_dict = {\n", 1670 | "# 'name': EXPERIMENT_NAME,\n", 1671 | "# 'embedding': 'own',\n", 1672 | "# 'architecture': 'Simple Conv1D',\n", 1673 | "# 'epochs': EPOCHS,\n", 1674 | "# 'batch_size': BATCH_SIZE,\n", 1675 | "# 'filter_count': FILTER_COUNT,\n", 1676 | "# 'filter_size': FILTER_SIZE,\n", 1677 | "# 'activation': ACTIVATION,\n", 1678 | "# 'conv_padding': CONV_PADDING,\n", 1679 | "# 'sequence': SEQUENCE\n", 1680 | "# }\n", 1681 | "# print(config_dict)\n", 1682 | "# config.update(\n", 1683 | "# config_dict\n", 1684 | "# )\n", 1685 | "\n", 1686 | "titles = ['Own Embedding', 'Static GloVe', 'Retrained GloVe']\n", 1687 | "embedding_layers = [\n", 1688 | " \n", 1689 | " # Randomly Initialized Embedding\n", 1690 | " Embedding(\n", 1691 | " TOKEN_COUNT,\n", 1692 | " EMBED_SIZE, \n", 1693 | " input_length=X_train.shape[1],\n", 1694 | " embeddings_initializer=RandomUniform(),\n", 1695 | " ),\n", 1696 | " \n", 1697 | " # Static transfer of GloVe embedding\n", 1698 | " Embedding(\n", 1699 | " TOKEN_COUNT,\n", 1700 | " EMBED_SIZE,\n", 1701 | " weights=[embedding_matrix],\n", 1702 | " input_length=MAX_LEN,\n", 1703 | " trainable=False\n", 1704 | " ),\n", 1705 | " \n", 1706 | " # Retraining of GloVe embedding\n", 1707 | " Embedding(\n", 1708 | " TOKEN_COUNT,\n", 1709 | " EMBED_SIZE,\n", 1710 | " weights=[embedding_matrix],\n", 1711 | " input_length=MAX_LEN,\n", 1712 | " trainable=True\n", 1713 | " ),\n", 1714 | "]\n", 1715 | "\n", 1716 | "for title, emb_layer in zip(titles, embedding_layers):\n", 1717 | " model = Sequential()\n", 1718 | "\n", 1719 | " model.add(emb_layer)\n", 1720 | " model.add(Dropout(0.1))\n", 1721 | " model.add(\n", 1722 | " Conv1D(\n", 1723 | " FILTER_COUNT, \n", 1724 | " FILTER_SIZE, \n", 1725 | " padding=CONV_PADDING, \n", 1726 | " activation=ACTIVATION, \n", 1727 | " strides=1\n", 1728 | " )\n", 1729 | " )\n", 1730 | " model.add(GlobalMaxPool1D())\n", 1731 | " model.add(\n", 1732 | " Dense(\n", 1733 | " y_train.shape[1],\n", 1734 | " 
activation='sigmoid',\n", 1735 | " )\n", 1736 | " )\n", 1737 | "\n", 1738 | " model.compile(\n", 1739 | " optimizer='adam',\n", 1740 | " loss='binary_crossentropy',\n", 1741 | " metrics=[\n", 1742 | " tf.keras.metrics.CategoricalAccuracy(),\n", 1743 | " tf.keras.metrics.Precision(),\n", 1744 | " tf.keras.metrics.Recall(),\n", 1745 | " tf.keras.metrics.AUC(),\n", 1746 | " tf.keras.metrics.TruePositives(),\n", 1747 | " tf.keras.metrics.FalsePositives(),\n", 1748 | " tf.keras.metrics.TrueNegatives(),\n", 1749 | " tf.keras.metrics.FalseNegatives(),\n", 1750 | " ]\n", 1751 | " )\n", 1752 | " model.summary()\n", 1753 | "\n", 1754 | " callbacks = [\n", 1755 | " ReduceLROnPlateau(\n", 1756 | " monitor='val_categorical_accuracy',\n", 1757 | " factor=0.1,\n", 1758 | " patience=1,\n", 1759 | " verbose=1,\n", 1760 | " ), \n", 1761 | " EarlyStopping(\n", 1762 | " monitor='val_categorical_accuracy',\n", 1763 | " patience=2,\n", 1764 | " verbose=1,\n", 1765 | " ), \n", 1766 | " ModelCheckpoint(\n", 1767 | " filepath='models/cnn_tagger.weights.hdf5',\n", 1768 | " monitor='val_categorical_accuracy',\n", 1769 | " save_best_only=True,\n", 1770 | " verbose=1,\n", 1771 | " ),\n", 1772 | " # WandbCallback()\n", 1773 | " ]\n", 1774 | "\n", 1775 | " history = model.fit(X_train, y_train,\n", 1776 | " class_weight=train_class_weights,\n", 1777 | " epochs=EPOCHS,\n", 1778 | " batch_size=BATCH_SIZE,\n", 1779 | " validation_split=TEST_SPLIT,\n", 1780 | " callbacks=callbacks)\n", 1781 | " \n", 1782 | " model = tf.keras.models.load_model('models/cnn_tagger.weights.hdf5')\n", 1783 | " metrics = model.evaluate(X_test, y_test)\n", 1784 | " \n", 1785 | " log = {}\n", 1786 | " for name, val in zip(model.metrics_names, metrics):\n", 1787 | "\n", 1788 | " repeat_name, py_val = lib.utils.fix_metric(name, val)\n", 1789 | " log[repeat_name] = py_val\n", 1790 | "\n", 1791 | " # Add a name and sequence number and an F1 score\n", 1792 | " log.update({'name': title})\n", 1793 | " log.update({'sequence': SEQUENCE})\n", 1794 | " log.update({'f1': (log['precision'] * log['recall']) / (log['precision'] + log['recall'])})\n", 1795 | "\n", 1796 | " simple_log.append(log)\n", 1797 | "\n", 1798 | " # Overwrite the old log\n", 1799 | " with open('simple_log.jsonl', 'w') as f:\n", 1800 | " [f.write(json.dumps(l) + '\\n') for l in simple_log]\n", 1801 | "\n", 1802 | "pd.DataFrame(simple_log)" 1803 | ] 1804 | }, 1805 | { 1806 | "cell_type": "markdown", 1807 | "metadata": {}, 1808 | "source": [ 1809 | "## Plot the Epoch Accuracy\n", 1810 | "\n", 1811 | "We want to know the performance at each epoch so that we don't train needlessly large numbers of epochs. 
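One caveat about the metrics logged by the training loop above: the `f1` value is computed as `precision * recall / (precision + recall)`, which is half of the conventional F1 score, the harmonic mean of precision and recall. If you want the standard definition, include the factor of two; a minimal helper:

```python
def f1_score(precision: float, recall: float) -> float:
    """Conventional F1: the harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

f1_score(0.6, 0.3)  # 0.4, whereas p*r/(p+r) would give 0.2
```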
" 1812 | ] 1813 | }, 1814 | { 1815 | "cell_type": "code", 1816 | "execution_count": null, 1817 | "metadata": {}, 1818 | "outputs": [], 1819 | "source": [ 1820 | "%matplotlib inline\n", 1821 | "\n", 1822 | "new_history = {}\n", 1823 | "for key, metrics in history.history.items():\n", 1824 | " new_history[lib.utils.fix_metric_name(key)] = metrics\n", 1825 | "\n", 1826 | "import matplotlib.pyplot as plt\n", 1827 | "\n", 1828 | "\n", 1829 | "# summarize history for accuracy\n", 1830 | "fig = plt.gcf()\n", 1831 | "fig.set_size_inches(12, 8, forward=True)\n", 1832 | "\n", 1833 | "viz_keys = ['val_categorical_accuracy', 'val_precision', 'val_recall']\n", 1834 | "for key in viz_keys:\n", 1835 | " plt.plot(new_history[key])\n", 1836 | "plt.title('model accuracy')\n", 1837 | "plt.ylabel('metric')\n", 1838 | "plt.xlabel('epoch')\n", 1839 | "plt.legend(viz_keys, loc='upper left')\n", 1840 | "plt.show()\n", 1841 | "\n", 1842 | "\n", 1843 | "# summarize history for loss\n", 1844 | "fig = plt.gcf()\n", 1845 | "fig.set_size_inches(12, 8, forward=True)\n", 1846 | "\n", 1847 | "plt.plot(history.history['loss'])\n", 1848 | "plt.plot(history.history['val_loss'])\n", 1849 | "plt.title('model loss')\n", 1850 | "plt.ylabel('loss')\n", 1851 | "plt.xlabel('epoch')\n", 1852 | "plt.legend(['train', 'test'], loc='upper left')\n", 1853 | "plt.show()" 1854 | ] 1855 | }, 1856 | { 1857 | "cell_type": "markdown", 1858 | "metadata": {}, 1859 | "source": [ 1860 | "## Train a Kim-CNN Model to Label Stack Overflow Questions\n", 1861 | "\n", 1862 | "Once again we’re ready to train a model to classify/label questions with tag categories. The model is based on [Kim-CNN](https://arxiv.org/abs/1408.5882), a commonly used convolutional neural network for sentence and document classification. We use the functional API and we’ve heavily parametrized the code so as to facilitate experimentation. 
" 1863 | ] 1864 | }, 1865 | { 1866 | "cell_type": "code", 1867 | "execution_count": null, 1868 | "metadata": {}, 1869 | "outputs": [], 1870 | "source": [ 1871 | "from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint\n", 1872 | "from tensorflow.keras.initializers import RandomUniform\n", 1873 | "from tensorflow.keras.layers import (\n", 1874 | " Dense, Activation, Embedding, Flatten, MaxPool1D, GlobalMaxPool1D, \n", 1875 | " Dropout, Conv1D, Input, concatenate, Reshape\n", 1876 | ")\n", 1877 | "from tensorflow.keras.losses import binary_crossentropy, kld\n", 1878 | "from tensorflow.keras.models import Model\n", 1879 | "from tensorflow.keras.optimizers import Adam\n", 1880 | "\n", 1881 | "# from keras_radam import RAdam\n", 1882 | "\n", 1883 | "tf.compat.v1.disable_eager_execution()\n", 1884 | "\n", 1885 | "EXPERIMENT_NAME = 'kim_cnn_2000_3_4_5_7_again_2'\n", 1886 | "\n", 1887 | "FILTER_COUNT = 128\n", 1888 | "FILTER_SIZE = [3, 4, 5, 7]\n", 1889 | "EPOCHS = 8\n", 1890 | "ACTIVATION = 'selu'\n", 1891 | "CONV_PADDING = 'same'\n", 1892 | "EMBED_SIZE = 50\n", 1893 | "EMBED_DROPOUT_RATIO = 0.1\n", 1894 | "CONV_DROPOUT_RATIO = 0.1\n", 1895 | "\n", 1896 | "if len(simple_log) > 0 and EXPERIMENT_NAME == simple_log[-1]['name']:\n", 1897 | " print('RENAME YOUR EXPERIMENT')\n", 1898 | " raise Exception('RENAME YOUR EXPERIMENT')\n", 1899 | "\n", 1900 | "SEQUENCE += 1\n", 1901 | "\n", 1902 | "# # Weights and Biases Monitoring\n", 1903 | "# import wandb\n", 1904 | "# from wandb.keras import WandbCallback\n", 1905 | "# wandb.init(project=\"weakly-supervised-learning\", name=EXPERIMENT_NAME)\n", 1906 | "# config = wandb.config\n", 1907 | "\n", 1908 | "# config.update(\n", 1909 | "# {\n", 1910 | "# 'name': EXPERIMENT_NAME,\n", 1911 | "# 'embedding': 'own',\n", 1912 | "# 'architecture': 'Kim CNN',\n", 1913 | "# 'epochs': EPOCHS,\n", 1914 | "# 'batch_size': BATCH_SIZE,\n", 1915 | "# 'filter_count': FILTER_COUNT,\n", 1916 | "# 'filter_size': FILTER_SIZE,\n", 1917 | "# 'activation': ACTIVATION,\n", 1918 | "# 'conv_padding': CONV_PADDING,\n", 1919 | "# 'sequence': SEQUENCE\n", 1920 | "# }\n", 1921 | "# )\n", 1922 | "\n", 1923 | "padded_input = Input(\n", 1924 | " shape=(X_train.shape[1],),\n", 1925 | " dtype='int32'\n", 1926 | ")\n", 1927 | "\n", 1928 | "emb = Embedding(\n", 1929 | " TOKEN_COUNT, \n", 1930 | " EMBED_SIZE,\n", 1931 | " embeddings_initializer=RandomUniform(),\n", 1932 | " input_length=X_train.shape[1]\n", 1933 | ")(padded_input)\n", 1934 | "# emb = Embedding(\n", 1935 | "# TOKEN_COUNT,\n", 1936 | "# EMBED_SIZE,\n", 1937 | "# weights=[embedding_matrix],\n", 1938 | "# input_length=MAX_LEN,\n", 1939 | "# trainable=True,\n", 1940 | "# )(padded_input)\n", 1941 | "drp = Dropout(0.1)(emb)\n", 1942 | "\n", 1943 | "# Create convlutions of different sizes\n", 1944 | "convs = []\n", 1945 | "for filter_size in FILTER_SIZE:\n", 1946 | " f_conv = Conv1D(\n", 1947 | " filters=FILTER_COUNT,\n", 1948 | " kernel_size=filter_size,\n", 1949 | " padding=CONV_PADDING,\n", 1950 | " activation=ACTIVATION\n", 1951 | " )(drp)\n", 1952 | " f_shape = Reshape((MAX_LEN * EMBED_SIZE, 1))(f_conv)\n", 1953 | " # f_pool = GlobalMaxPool1D()(f_shape)\n", 1954 | " f_pool = MaxPool1D(filter_size)(f_conv)\n", 1955 | " convs.append(f_pool)\n", 1956 | "\n", 1957 | "l_merge = concatenate(convs, axis=1)\n", 1958 | "l_conv = Conv1D(\n", 1959 | " 128,\n", 1960 | " 5,\n", 1961 | " activation=ACTIVATION\n", 1962 | ")(l_merge)\n", 1963 | "l_pool = GlobalMaxPool1D()(l_conv)\n", 1964 | "l_flat = 
Flatten()(l_pool)\n", 1965 | "l_drp = Dropout(CONV_DROPOUT_RATIO)(l_flat)\n", 1966 | "l_dense = Dense(\n", 1967 | " 60,\n", 1968 | " activation=ACTIVATION\n", 1969 | ")(l_drp)\n", 1970 | "out_dense = Dense(\n", 1971 | " y_train.shape[1],\n", 1972 | " activation='sigmoid'\n", 1973 | ")(l_dense)\n", 1974 | "\n", 1975 | "model = Model(inputs=padded_input, outputs=out_dense)\n", 1976 | "\n", 1977 | "model.compile(\n", 1978 | " optimizer='adam',\n", 1979 | " loss='binary_crossentropy',\n", 1980 | " metrics=[\n", 1981 | " tf.keras.metrics.CategoricalAccuracy(),\n", 1982 | " tf.keras.metrics.Precision(),\n", 1983 | " tf.keras.metrics.Recall(),\n", 1984 | " tf.keras.metrics.AUC(),\n", 1985 | " tf.keras.metrics.TruePositives(),\n", 1986 | " tf.keras.metrics.FalsePositives(),\n", 1987 | " tf.keras.metrics.TrueNegatives(),\n", 1988 | " tf.keras.metrics.FalseNegatives(),\n", 1989 | " ]\n", 1990 | ")\n", 1991 | "model.summary()\n", 1992 | "\n", 1993 | "callbacks = [\n", 1994 | " ReduceLROnPlateau(\n", 1995 | " monitor='val_categorical_accuracy',\n", 1996 | " factor=0.1,\n", 1997 | " patience=1,\n", 1998 | " verbose=1,\n", 1999 | " ), \n", 2000 | " EarlyStopping(\n", 2001 | " monitor='val_categorical_accuracy',\n", 2002 | " patience=2,\n", 2003 | " verbose=1,\n", 2004 | " ), \n", 2005 | " ModelCheckpoint(\n", 2006 | " filepath='models/cnn_tagger.weights.hdf5',\n", 2007 | " monitor='val_categorical_accuracy',\n", 2008 | " save_best_only=True,\n", 2009 | " verbose=1,\n", 2010 | " ),\n", 2011 | " # WandbCallback()\n", 2012 | "]\n", 2013 | "\n", 2014 | "history = model.fit(X_train, y_train,\n", 2015 | " class_weight=train_class_weights,\n", 2016 | " epochs=EPOCHS,\n", 2017 | " batch_size=BATCH_SIZE,\n", 2018 | " validation_data=(X_test, y_test),\n", 2019 | " callbacks=callbacks)" 2020 | ] 2021 | }, 2022 | { 2023 | "cell_type": "code", 2024 | "execution_count": null, 2025 | "metadata": {}, 2026 | "outputs": [], 2027 | "source": [ 2028 | "model = tf.keras.models.load_model('models/cnn_tagger.weights.hdf5')\n", 2029 | "metrics = model.evaluate(X_test, y_test)" 2030 | ] 2031 | }, 2032 | { 2033 | "cell_type": "code", 2034 | "execution_count": null, 2035 | "metadata": {}, 2036 | "outputs": [], 2037 | "source": [ 2038 | "log = {}\n", 2039 | "for name, val in zip(model.metrics_names, metrics):\n", 2040 | " \n", 2041 | " repeat_name, py_val = lib.utils.fix_metric(name, val)\n", 2042 | " log[repeat_name] = py_val\n", 2043 | "\n", 2044 | "# Add a name and sequence number and an F1 score\n", 2045 | "log.update({'name': EXPERIMENT_NAME})\n", 2046 | "log.update({'sequence': SEQUENCE})\n", 2047 | "log.update({'f1': (log['precision'] * log['recall']) / (log['precision'] + log['recall'])})\n", 2048 | "\n", 2049 | "simple_log.append(log)\n", 2050 | "\n", 2051 | "# Overwrite the old log\n", 2052 | "with open('simple_log.jsonl', 'w') as f:\n", 2053 | " [f.write(json.dumps(l) + '\\n') for l in simple_log]\n", 2054 | "\n", 2055 | "pd.DataFrame([log])" 2056 | ] 2057 | }, 2058 | { 2059 | "cell_type": "code", 2060 | "execution_count": null, 2061 | "metadata": {}, 2062 | "outputs": [], 2063 | "source": [ 2064 | "%matplotlib inline\n", 2065 | "\n", 2066 | "new_history = {}\n", 2067 | "for key, metrics in history.history.items():\n", 2068 | " new_history[lib.utils.fix_metric_name(key)] = metrics\n", 2069 | "\n", 2070 | "import matplotlib.pyplot as plt\n", 2071 | "\n", 2072 | "fig = plt.gcf()\n", 2073 | "fig.set_size_inches(12, 8, forward=True)\n", 2074 | "\n", 2075 | "viz_keys = ['val_categorical_accuracy', 'val_precision', 
'val_recall']\n", 2076 | "# summarize history for accuracy\n", 2077 | "for key in viz_keys:\n", 2078 | " plt.plot(new_history[key])\n", 2079 | "plt.title('model accuracy')\n", 2080 | "plt.ylabel('metric')\n", 2081 | "plt.xlabel('epoch')\n", 2082 | "plt.legend(viz_keys, loc='upper left')\n", 2083 | "plt.show()\n", 2084 | "\n", 2085 | "fig = plt.gcf()\n", 2086 | "fig.set_size_inches(12, 8, forward=True)\n", 2087 | "\n", 2088 | "# summarize history for loss\n", 2089 | "plt.plot(history.history['loss'])\n", 2090 | "plt.plot(history.history['val_loss'])\n", 2091 | "plt.title('model loss')\n", 2092 | "plt.ylabel('loss')\n", 2093 | "plt.xlabel('epoch')\n", 2094 | "plt.legend(['train', 'test'], loc='upper left')\n", 2095 | "plt.show()" 2096 | ] 2097 | }, 2098 | { 2099 | "cell_type": "markdown", 2100 | "metadata": {}, 2101 | "source": [ 2102 | "## Compare this Run to the 1st and Previous Run\n", 2103 | "\n", 2104 | "To get an idea of performance we need to see where we started and where we just came from." 2105 | ] 2106 | }, 2107 | { 2108 | "cell_type": "code", 2109 | "execution_count": null, 2110 | "metadata": {}, 2111 | "outputs": [], 2112 | "source": [ 2113 | "# Compare to original\n", 2114 | "if len(simple_log) > 1:\n", 2115 | " d2 = simple_log[-1]\n", 2116 | " d1 = simple_log[0]\n", 2117 | "else:\n", 2118 | " d1 = simple_log[0]\n", 2119 | " d2 = simple_log[0]\n", 2120 | "log_diff_1 = {key: d2.get(key, 0) - d1.get(key, 0) for key in d1.keys() if key not in ['name', 'sequence']}\n", 2121 | "log_diff_1['current'] = d2['name']\n", 2122 | "log_diff_1['previous'] = d1['name']\n", 2123 | "\n", 2124 | "# Compare to last run\n", 2125 | "if len(simple_log) > 1:\n", 2126 | " d1 = simple_log[-2]\n", 2127 | " d2 = simple_log[-1]\n", 2128 | "else:\n", 2129 | " d1 = simple_log[0]\n", 2130 | " d2 = simple_log[0]\n", 2131 | " \n", 2132 | "log_diff_2 = {key: d2.get(key, 0) - d1.get(key, 0) for key in d1.keys() if key not in ['name', 'sequence']}\n", 2133 | "log_diff_2['current'] = d2['name']\n", 2134 | "log_diff_2['previous'] = d1['name']\n", 2135 | "\n", 2136 | "df = pd.DataFrame.from_dict([log_diff_1, log_diff_2])\n", 2137 | "cols = df.columns.tolist()\n", 2138 | "cols.remove('previous')\n", 2139 | "cols.remove('current')\n", 2140 | "show_cols = ['previous', 'current'] + cols\n", 2141 | "df[show_cols]" 2142 | ] 2143 | }, 2144 | { 2145 | "cell_type": "markdown", 2146 | "metadata": {}, 2147 | "source": [ 2148 | "## View the Last 10 Experiments\n", 2149 | "\n", 2150 | "It can be helpful to see trends of performance among experiments." 2151 | ] 2152 | }, 2153 | { 2154 | "cell_type": "code", 2155 | "execution_count": null, 2156 | "metadata": {}, 2157 | "outputs": [], 2158 | "source": [ 2159 | "log_df = pd.DataFrame(simple_log)\n", 2160 | "log_df['f1'] = (log_df['precision'] * log_df['recall']) / (log_df['precision'] + log_df['recall'])\n", 2161 | "\n", 2162 | "log_df[[\n", 2163 | " 'sequence',\n", 2164 | " 'name',\n", 2165 | " 'loss',\n", 2166 | " 'categorical_accuracy',\n", 2167 | " 'precision',\n", 2168 | " 'recall',\n", 2169 | " 'f1',\n", 2170 | " 'auc',\n", 2171 | " 'true_positives',\n", 2172 | " 'false_positives',\n", 2173 | " 'true_negatives',\n", 2174 | " 'false_negatives',\n", 2175 | " 'hinge',\n", 2176 | " 'mean_absolute_error',\n", 2177 | "]][0:10 if len(log_df) > 9 else len(log_df)]" 2178 | ] 2179 | }, 2180 | { 2181 | "cell_type": "markdown", 2182 | "metadata": {}, 2183 | "source": [ 2184 | "## Check the Actual Prediction Outputs\n", 2185 | "\n", 2186 | "It is not enough to know theoretical performance. 
We need to see the actual output of the tagger at different confidence thresholds." 2187 | ] 2188 | }, 2189 | { 2190 | "cell_type": "code", 2191 | "execution_count": null, 2192 | "metadata": {}, 2193 | "outputs": [], 2194 | "source": [ 2195 | "TEST_COUNT = 1000\n", 2196 | "\n", 2197 | "X_test_text = tokenizer.sequences_to_texts(X_test[:TEST_COUNT])\n", 2198 | "\n", 2199 | "y_test_tags = []\n", 2200 | "for row in y_test[:TEST_COUNT].tolist():\n", 2201 | " tags = [index_tag[str(i)] for i, col in enumerate(row) if col == 1]\n", 2202 | " y_test_tags.append(tags)" 2203 | ] 2204 | }, 2205 | { 2206 | "cell_type": "markdown", 2207 | "metadata": {}, 2208 | "source": [ 2209 | "## Adjust the threshold for classification\n", 2210 | "\n", 2211 | "This lets us see how well the model generalizes to labeling more classes." 2212 | ] 2213 | }, 2214 | { 2215 | "cell_type": "code", 2216 | "execution_count": null, 2217 | "metadata": {}, 2218 | "outputs": [], 2219 | "source": [ 2220 | "CLASSIFY_THRESHOLD = 0.5\n", 2221 | "\n", 2222 | "y_pred = model.predict(X_test)\n", 2223 | "y_pred = (y_pred > CLASSIFY_THRESHOLD) * 1\n", 2224 | "\n", 2225 | "y_pred_tags = []\n", 2226 | "for row in y_pred[:TEST_COUNT].tolist():\n", 2227 | " tags = [index_tag[str(i)] for i, col in enumerate(row) if col > CLASSIFY_THRESHOLD]\n", 2228 | " y_pred_tags.append(tags)" 2229 | ] 2230 | }, 2231 | { 2232 | "cell_type": "markdown", 2233 | "metadata": {}, 2234 | "source": [ 2235 | "## See How Far off we are per Class" 2236 | ] 2237 | }, 2238 | { 2239 | "cell_type": "code", 2240 | "execution_count": null, 2241 | "metadata": {}, 2242 | "outputs": [], 2243 | "source": [ 2244 | "np.around(y_pred, 0).sum(axis=0) - y_test.sum(axis=0)" 2245 | ] 2246 | }, 2247 | { 2248 | "cell_type": "markdown", 2249 | "metadata": {}, 2250 | "source": [ 2251 | "### View Prediction Results\n", 2252 | "\n", 2253 | "It is better to view the results in a DataFrame." 2254 | ] 2255 | }, 2256 | { 2257 | "cell_type": "code", 2258 | "execution_count": null, 2259 | "metadata": {}, 2260 | "outputs": [], 2261 | "source": [ 2262 | "prediction_tests = []\n", 2263 | "for x, y, z in zip(X_test_text, y_pred_tags, y_test_tags):\n", 2264 | " prediction_tests.append({\n", 2265 | " 'Question': x,\n", 2266 | " 'Actual': ' '.join(sorted(z)),\n", 2267 | " 'Predictions': ' '.join(sorted(y)),\n", 2268 | " })\n", 2269 | "\n", 2270 | "pd.set_option('display.max_colwidth', 300)\n", 2271 | "pd.DataFrame(prediction_tests)[['Question', 'Actual', 'Predictions']]" 2272 | ] 2273 | }, 2274 | { 2275 | "cell_type": "markdown", 2276 | "metadata": {}, 2277 | "source": [ 2278 | "## The Big Finish\n", 2279 | "\n", 2280 | "That is the big finish!" 
2281 | ] 2282 | } 2283 | ], 2284 | "metadata": { 2285 | "kernelspec": { 2286 | "display_name": "Python 3", 2287 | "language": "python", 2288 | "name": "python3" 2289 | }, 2290 | "language_info": { 2291 | "codemirror_mode": { 2292 | "name": "ipython", 2293 | "version": 3 2294 | }, 2295 | "file_extension": ".py", 2296 | "mimetype": "text/x-python", 2297 | "name": "python", 2298 | "nbconvert_exporter": "python", 2299 | "pygments_lexer": "ipython3", 2300 | "version": "3.7.4" 2301 | }, 2302 | "toc": { 2303 | "base_numbering": 1, 2304 | "nav_menu": {}, 2305 | "number_sections": false, 2306 | "sideBar": false, 2307 | "skip_h1_title": false, 2308 | "title_cell": "Table of Contents", 2309 | "title_sidebar": "Contents", 2310 | "toc_cell": false, 2311 | "toc_position": {}, 2312 | "toc_section_display": false, 2313 | "toc_window_display": false 2314 | } 2315 | }, 2316 | "nbformat": 4, 2317 | "nbformat_minor": 2 2318 | } 2319 | -------------------------------------------------------------------------------- /ch04/PREREQUISITES.md: -------------------------------------------------------------------------------- 1 | # Chapter 3 Prerequisites 2 | 3 | In order to work the Github BERT Embedding examples, you will have to download, build and install the following software: 4 | 5 | ## Google SentencePiece 6 | 7 | ```bash 8 | git clone https://github.com/google/sentencepiece 9 | cd sentencepiece 10 | 11 | mkdir build 12 | cd build 13 | cmake .. 14 | make -j $(nproc) 15 | sudo make install 16 | sudo ldconfig -v 17 | ``` 18 | 19 | 20 | 21 | ## Google BERT 22 | 23 | git clone https://github.com/google-research/bert 24 | cd bert 25 | 26 | conda create -n bert -y python=3.7.5 27 | conda install -y pip 28 | pip install tensorflow-gpu==1.14.0 29 | 30 | python bert/create_pretraining_data.py \ 31 | --input_file=data/sentences.csv \ 32 | --output_file=data/tf_examples.tfrecord \ 33 | --vocab_file=models/wsl.vocab \ 34 | --bert_config_file=./bert/bert_config.json \ 35 | --do_lower_case=False \ 36 | --max_seq_length=128 \ 37 | --max_predictions_per_seq=20 \ 38 | --num_train_steps=20 \ 39 | --num_warmup_steps=10 \ 40 | --random_seed=1337 \ 41 | --learning_rate=2e-5 42 | 43 | conda deactivate -------------------------------------------------------------------------------- /ch04/bert.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Use sentencepiece to create a WordPiece vocabulary for the data 4 | git clone https://github.com/google/sentencepiece 5 | cd sentencepiece 6 | mkdir build 7 | cd build 8 | cmake .. 
9 | make -j $(nproc) 10 | sudo make install 11 | sudo ldconfig -v 12 | 13 | 14 | VOCAB_SIZE=200000 15 | 16 | spm_train --input=./sentences.csv --model_prefix=wsl --vocab_size=${VOCAB_SIZE} 17 | 18 | git clone https://github.com/google-research/bert 19 | -------------------------------------------------------------------------------- /ch05/Distant Supervision.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Sources of Distant Supervision for NER/IE from Stack Overflow Posts\n", 8 | "\n", 9 | "In this notebook we will be acquiring sources of distant supervision for our Information Extraction models using SPARQL queries on the WikiData dataset.\n", 10 | "\n", 11 | "## WikiData Programming Languages\n", 12 | "\n", 13 | "For the Snorkel example for Chapter 5, we create a programming language extractor from the titles and bodies of Stack Overflow questions. Here we generate the file that we used by querying WikiData using SPARQL to get a list of programming languages. We then use these language names to label positive examples of programming languages in posts for training our discriminative/network extractor model.\n", 14 | "\n", 15 | "The following SPARQL query prints out the names of all [Property:31:instances of](https://www.wikidata.org/wiki/Property:P31) [Item:Q9143 programming languages](https://www.wikidata.org/wiki/Q9143) in English content from WikiData.\n", 16 | "\n", 17 | "We `SELECT DISTINCT` the item and item labels, then filter the language of the item label to English, to avoid duplicate content from other languages." 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "!pip install -q jsonlines requests\n", 27 | "\n", 28 | "import json\n", 29 | "import jsonlines\n", 30 | "import requests" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "url = 'https://query.wikidata.org/sparql'\n", 40 | "query = \"\"\"\n", 41 | "# Get all programming language names from English sources\n", 42 | "SELECT DISTINCT ?item ?item_label\n", 43 | "WHERE {\n", 44 | " ?item wdt:P31 wd:Q9143 # P31:instances of Q9143:programming language\n", 45 | " ; rdfs:label ?item_label .\n", 46 | " \n", 47 | " FILTER (LANG(?item_label) = \"en\"). 
# English only\n", 48 | "}\n", 49 | "\"\"\"\n", 50 | "r = requests.get(url, params = {'format': 'json', 'query': query})\n", 51 | "data = r.json()" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 6, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "[\n", 64 | " {\n", 65 | " \"item\": {\n", 66 | " \"type\": \"uri\",\n", 67 | " \"value\": \"http://www.wikidata.org/entity/Q2005\"\n", 68 | " },\n", 69 | " \"item_label\": {\n", 70 | " \"type\": \"literal\",\n", 71 | " \"value\": \"JavaScript\",\n", 72 | " \"xml:lang\": \"en\"\n", 73 | " }\n", 74 | " },\n", 75 | " {\n", 76 | " \"item\": {\n", 77 | " \"type\": \"uri\",\n", 78 | " \"value\": \"http://www.wikidata.org/entity/Q1374139\"\n", 79 | " },\n", 80 | " \"item_label\": {\n", 81 | " \"type\": \"literal\",\n", 82 | " \"value\": \"Euphoria\",\n", 83 | " \"xml:lang\": \"en\"\n", 84 | " }\n", 85 | " },\n", 86 | " {\n", 87 | " \"item\": {\n", 88 | " \"type\": \"uri\",\n", 89 | " \"value\": \"http://www.wikidata.org/entity/Q1334586\"\n", 90 | " },\n", 91 | " \"item_label\": {\n", 92 | " \"type\": \"literal\",\n", 93 | " \"value\": \"Emacs Lisp\",\n", 94 | " \"xml:lang\": \"en\"\n", 95 | " }\n", 96 | " },\n", 97 | " {\n", 98 | " \"item\": {\n", 99 | " \"type\": \"uri\",\n", 100 | " \"value\": \"http://www.wikidata.org/entity/Q1356671\"\n", 101 | " },\n", 102 | " \"item_label\": {\n", 103 | " \"type\": \"literal\",\n", 104 | " \"value\": \"GT.M\",\n", 105 | " \"xml:lang\": \"en\"\n", 106 | " }\n", 107 | " }\n", 108 | "]\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "print(json.dumps(data[\"results\"][\"bindings\"][0:4], indent=4, sort_keys=True))" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "## Extract the Language Labels from nested JSON\n", 121 | "\n", 122 | "Nested JSON is a pain to work with in `DataFrames`, so we un-nest it, retaining only what we need." 
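To see why, flatten the raw bindings directly (assuming a recent pandas with `json_normalize`); you get awkward dotted column names instead of the clean `name`, `kb_url` and `kb_id` fields we build next:

```python
import pandas as pd

# Flattening the nested SPARQL bindings as-is produces dotted column names,
# e.g. ['item.type', 'item.value', 'item_label.type', 'item_label.value', 'item_label.xml:lang']
raw_df = pd.json_normalize(data['results']['bindings'])
raw_df.columns.tolist()
```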
123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 8, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "There were 1,417 languages returned.\n", 135 | "\n", 136 | "{'name': 'JavaScript', 'kb_url': 'http://www.wikidata.org/entity/Q2005', 'kb_id': 'Q2005'}\n", 137 | "{'name': 'Euphoria', 'kb_url': 'http://www.wikidata.org/entity/Q1374139', 'kb_id': 'Q1374139'}\n", 138 | "{'name': 'Emacs Lisp', 'kb_url': 'http://www.wikidata.org/entity/Q1334586', 'kb_id': 'Q1334586'}\n", 139 | "{'name': 'GT.M', 'kb_url': 'http://www.wikidata.org/entity/Q1356671', 'kb_id': 'Q1356671'}\n", 140 | "{'name': 'REBOL', 'kb_url': 'http://www.wikidata.org/entity/Q1359171', 'kb_id': 'Q1359171'}\n", 141 | "{'name': 'Embedded SQL', 'kb_url': 'http://www.wikidata.org/entity/Q1335009', 'kb_id': 'Q1335009'}\n", 142 | "{'name': 'SystemVerilog', 'kb_url': 'http://www.wikidata.org/entity/Q1387402', 'kb_id': 'Q1387402'}\n", 143 | "{'name': 'BETA', 'kb_url': 'http://www.wikidata.org/entity/Q830842', 'kb_id': 'Q830842'}\n", 144 | "{'name': 'newLISP', 'kb_url': 'http://www.wikidata.org/entity/Q827233', 'kb_id': 'Q827233'}\n", 145 | "{'name': 'Verilog', 'kb_url': 'http://www.wikidata.org/entity/Q827773', 'kb_id': 'Q827773'}\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "languages = [\n", 151 | " {\n", 152 | " 'name': x['item_label']['value'],\n", 153 | " 'kb_url': x['item']['value'],\n", 154 | " 'kb_id': x['item']['value'].split('/')[-1], # Get the ID\n", 155 | " }\n", 156 | " for x in data['results']['bindings']\n", 157 | "]\n", 158 | "\n", 159 | "# Filter out an erroneous language\n", 160 | "languages = list(\n", 161 | " filter(\n", 162 | " lambda x: x['kb_id'] != 'Q25111344', \n", 163 | " languages\n", 164 | " )\n", 165 | ")\n", 166 | "\n", 167 | "print(f'There were {len(languages):,} languages returned.\\n')\n", 168 | "\n", 169 | "for l in languages[0:10]:\n", 170 | " print(l)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "## Write Languages to Disk as CSV" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 10, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "with jsonlines.open('programming_languages.jsonl', mode='w') as writer:\n", 187 | " writer.write_all(languages)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "## Now get a list of operating systems to create negative LFs from" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 11, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "url = 'https://query.wikidata.org/sparql'\n", 204 | "query = \"\"\"\n", 205 | "# Get all operating system names from English sources\n", 206 | "SELECT DISTINCT ?item ?item_label\n", 207 | "WHERE {\n", 208 | " ?item wdt:P31 wd:Q9135 # instances of operating system\n", 209 | " ; rdfs:label ?item_label .\n", 210 | " \n", 211 | " FILTER (LANG(?item_label) = \"en\"). 
\n", 212 | "}\n", 213 | "\"\"\"\n", 214 | "r = requests.get(url, params = {'format': 'json', 'query': query})\n", 215 | "data = r.json()" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 12, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "There were 1,066 programs returned.\n", 228 | "\n", 229 | "{'name': 'Windows 8', 'kb_url': 'http://www.wikidata.org/entity/Q5046', 'kb_id': 'Q5046'}\n", 230 | "{'name': 'Möbius', 'kb_url': 'http://www.wikidata.org/entity/Q3869245', 'kb_id': 'Q3869245'}\n", 231 | "{'name': 'ITIX', 'kb_url': 'http://www.wikidata.org/entity/Q3789886', 'kb_id': 'Q3789886'}\n", 232 | "{'name': 'TinyKRNL', 'kb_url': 'http://www.wikidata.org/entity/Q3991642', 'kb_id': 'Q3991642'}\n", 233 | "{'name': 'Myarc Disk Operating System', 'kb_url': 'http://www.wikidata.org/entity/Q3841260', 'kb_id': 'Q3841260'}\n", 234 | "{'name': 'NX-OS', 'kb_url': 'http://www.wikidata.org/entity/Q3869717', 'kb_id': 'Q3869717'}\n", 235 | "{'name': 'Unslung', 'kb_url': 'http://www.wikidata.org/entity/Q4006074', 'kb_id': 'Q4006074'}\n", 236 | "{'name': 'KnopILS', 'kb_url': 'http://www.wikidata.org/entity/Q3815960', 'kb_id': 'Q3815960'}\n", 237 | "{'name': 'Multiuser DOS', 'kb_url': 'http://www.wikidata.org/entity/Q3867065', 'kb_id': 'Q3867065'}\n", 238 | "{'name': 'MDOS', 'kb_url': 'http://www.wikidata.org/entity/Q3841258', 'kb_id': 'Q3841258'}\n" 239 | ] 240 | } 241 | ], 242 | "source": [ 243 | "oses = [\n", 244 | " {\n", 245 | " 'name': x['item_label']['value'],\n", 246 | " 'kb_url': x['item']['value'],\n", 247 | " 'kb_id': x['item']['value'].split('/')[-1], # Get the ID\n", 248 | " }\n", 249 | " for x in data['results']['bindings']\n", 250 | "]\n", 251 | "\n", 252 | "print(f'There were {len(oses):,} programs returned.\\n')\n", 253 | "\n", 254 | "for l in oses[0:10]:\n", 255 | " print(l)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 13, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "with jsonlines.open('operating_systems.jsonl', mode='w') as writer:\n", 265 | " writer.write_all(oses)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "## Conclusion\n", 273 | "\n", 274 | "Now we are ready to use our programming languages in our Label Functions (LFs) in the Snorkel notebook!" 
275 | ] 276 | } 277 | ], 278 | "metadata": { 279 | "kernelspec": { 280 | "display_name": "Python 3", 281 | "language": "python", 282 | "name": "python3" 283 | }, 284 | "language_info": { 285 | "codemirror_mode": { 286 | "name": "ipython", 287 | "version": 3 288 | }, 289 | "file_extension": ".py", 290 | "mimetype": "text/x-python", 291 | "name": "python", 292 | "nbconvert_exporter": "python", 293 | "pygments_lexer": "ipython3", 294 | "version": "3.7.4" 295 | } 296 | }, 297 | "nbformat": 4, 298 | "nbformat_minor": 2 299 | } 300 | -------------------------------------------------------------------------------- /ch05/bad_tags.spark.py: -------------------------------------------------------------------------------- 1 | # Seperate out each group of bag tags - documents and their single label as gold standard examples for weak supervision 2 | 3 | from pyspark.sql import SparkSession, Row 4 | import pyspark.sql.functions as F 5 | import pyspark.sql.types as T 6 | 7 | from lib.utils import one_hot_encode 8 | 9 | PATHS = { 10 | 'bad_questions': { 11 | 'local': 'data/stackoverflow/Questions.Bad.{}.{}.parquet', 12 | 's3': 's3://stackoverflow-events/Questions.Bad.{}.{}.parquet', 13 | }, 14 | 'bad_tag_counts': { 15 | 'local': 'data/stackoverflow/TagCounts.Bad.{}.{}.parquet', 16 | 's3': 's3://stackoverflow-events/TagCounts.Bad.{}.{}.parquet', 17 | }, 18 | 'one_hot': { 19 | 'local': 'data/stackoverflow/Questions.Bad.OneHot.{}.{}.parquet', 20 | 's3': 's3://stackoverflow-events/Questions.Bad.OneHot.{}.{}.parquet', 21 | }, 22 | 'final_tag_examples': { 23 | 'local': 'data/stackoverflow/PerTag.Bad.{}.{}.jsonl/{}.{}.jsonl', 24 | 's3': 's3://stackoverflow-events/PerTag.Bad.{}.{}.jsonl/{}.{}.jsonl', 25 | }, 26 | 'final_tag_all': { 27 | 'local': 'data/stackoverflow/PerTag.Bad.{}.{}.jsonl/*', 28 | 's3': 's3://stackoverflow-events/PerTag.Bad.{}.{}.jsonl/*', 29 | }, 30 | 'final_tag_parquet': { 31 | 'local': 'data/stackoverflow/PerTag.Bad.{}.{}.parquet', 32 | 's3': 's3://stackoverflow-events/PerTag.Bad.{}.{}.parquet', 33 | } 34 | } 35 | 36 | # Define a set of paths for each step for local and S3 37 | PATH_SET = 'local' # 's3' 38 | 39 | spark = SparkSession.builder\ 40 | .appName('Deep Products - Sample JSON')\ 41 | .config('spark.dynamicAllocation.enabled', True)\ 42 | .config('spark.shuffle.service.enabled', True)\ 43 | .getOrCreate() 44 | sc = spark.sparkContext 45 | 46 | tag_limit, stratify_limit, bad_limit = 2000, 2000, 500 47 | 48 | # Load the questions with tags occurring between 2000 - 500 times (note: does not include more numerous tags) 49 | bad_df = spark.read.parquet(PATHS['bad_questions'][PATH_SET].format(tag_limit, bad_limit)) 50 | 51 | # 52 | # Count the instances of each bad tag 53 | # 54 | all_tags = bad_df.rdd.flatMap(lambda x: x['_Tags']) 55 | tag_counts_df = all_tags\ 56 | .groupBy(lambda x: x)\ 57 | .map(lambda x: Row(tag=x[0], total=len(x[1])))\ 58 | .toDF()\ 59 | .select('tag', 'total').orderBy(['total'], ascending=False) 60 | 61 | tag_counts_df.write.mode('overwrite').parquet( 62 | PATHS['bad_tag_counts'][PATH_SET].format(tag_limit, bad_limit) 63 | ) 64 | tag_counts_df = spark.read.parquet( 65 | PATHS['bad_tag_counts'][PATH_SET].format(tag_limit, bad_limit) 66 | ) 67 | tag_total = tag_counts_df.count() 68 | tag_counts_df.show() 69 | 70 | 71 | # 72 | # Create indexes for each multilabel tag 73 | # 74 | enumerated_labels = [ 75 | z for z in enumerate( 76 | sorted( 77 | tag_counts_df.rdd 78 | .groupBy(lambda x: 1) 79 | .flatMap(lambda x: [y.tag for y in x[1]]) 80 | .collect() 81 | ) 82 | ) 83 | ] 
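# Editor's note (illustrative aside, not part of the original script): enumerated_labels built
# above is a list of (index, tag) pairs over the alphabetically sorted tag vocabulary, e.g.
# [(0, 'apache-flink'), (1, 'applescript'), ...] -- the tag names shown here are hypothetical.
# The two dicts created next invert that pairing: tag_index maps tag -> position and index_tag
# maps position -> tag, which is what lets the one_hot_encode() call and the per-tag export loop
# further below translate between tag strings and columns of the one-hot label vector.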
84 | tag_index = {x: i for i, x in enumerated_labels} 85 | index_tag = {i: x for i, x in enumerated_labels} 86 | 87 | 88 | # One hot encode the data using one_hot_encode() 89 | one_hot_questions = bad_df.rdd.map( 90 | lambda x: Row( 91 | _Body=x._Body, 92 | _Code=x._Code, 93 | _Tags=one_hot_encode(x._Tags, enumerated_labels, index_tag) 94 | ) 95 | ) 96 | 97 | # Create a DataFrame out of the one-hot encoded RDD 98 | schema = T.StructType( 99 | [ 100 | T.StructField('_Body', T.StringType()), 101 | T.StructField('_Code', T.StringType()), 102 | T.StructField( 103 | '_Tags', 104 | T.ArrayType( 105 | T.IntegerType() 106 | ), 107 | ) 108 | ] 109 | ) 110 | 111 | one_hot_df = spark.createDataFrame( 112 | one_hot_questions, 113 | schema 114 | ) 115 | one_hot_df.show() 116 | 117 | one_hot_df.write.mode('overwrite').parquet(PATHS['one_hot'][PATH_SET].format( 118 | tag_limit, bad_limit 119 | )) 120 | one_hot_df = spark.read.parquet(PATHS['one_hot'][PATH_SET].format( 121 | tag_limit, bad_limit 122 | )) 123 | 124 | 125 | # 126 | # Write out a seperate set of records for each label, where that label is positive 127 | # 128 | for i in range(0, tag_total): 129 | 130 | tag_str = index_tag[i] 131 | print(f'\n\nProcessing tag {tag_str} which is {i:,} of {tag_total:,} total tags\n\n') 132 | 133 | # Select records with a positive value for this tag 134 | positive_examples = one_hot_df.filter(F.col('_Tags')[i] == 1) 135 | 136 | # Select the current label column alone 137 | final_examples = positive_examples.select( 138 | '_Body', 139 | '_Code', 140 | F.lit(tag_str).cast(T.IntegerType()).alias('_Tag'), 141 | F.lit(i).alias('_Index'), 142 | ) 143 | 144 | # Write this tag's examples to a subdirectory as 1 JSON file, so we can load them individually as well as all at 145 | # once later 146 | final_examples.coalesce(1).write.mode('overwrite').json( 147 | PATHS['final_tag_examples'][PATH_SET].format(tag_limit, bad_limit, i, tag_str) 148 | ) 149 | 150 | # Specify a schema to load the JSON 151 | schema = T.StructType([ 152 | T.StructField('_Body', T.StringType()), 153 | T.StructField('_Code', T.StringType()), 154 | T.StructField('_Tag', T.StringType()), 155 | T.StructField('_Index', T.IntegerType()), 156 | ]) 157 | 158 | # Write as one DataFrame 159 | final_examples_all = spark.read.json( 160 | PATHS['final_tag_all'][PATH_SET].format(tag_limit, bad_limit) 161 | ) 162 | final_examples_all.coalesce(20).write.mode('overwrite').parquet( 163 | PATHS['final_tag_parquet'][PATH_SET].format(tag_limit, bad_limit), 164 | ) 165 | -------------------------------------------------------------------------------- /ch05/label.spark.py: -------------------------------------------------------------------------------- 1 | # Use Snorkel and PySpark to create weak labels for the data using Label Functions (LFs) 2 | 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | 7 | from pyspark.sql import SparkSession, Row 8 | import pyspark.sql.functions as F 9 | import pyspark.sql.types as T 10 | 11 | from snorkel.labeling.apply.spark import SparkLFApplier 12 | from snorkel.labeling import LabelingFunction 13 | 14 | 15 | # What limits for tag frequency we're working with 16 | TAG_LIMIT, BAD_LIMIT = 2000, 500 17 | 18 | PATHS = { 19 | 'bad_tag_counts': { 20 | 'local': 'data/stackoverflow/TagCounts.Bad.{}.{}.parquet'.format(TAG_LIMIT, BAD_LIMIT), 21 | 's3': 's3://stackoverflow-events/TagCounts.Bad.{}.{}.parquet'.format(TAG_LIMIT, BAD_LIMIT), 22 | }, 23 | 'bad_questions': { 24 | 'local': 
'data/stackoverflow/Questions.Bad.{}.{}.parquet'.format(TAG_LIMIT, BAD_LIMIT), 25 | 's3': 's3://stackoverflow-events/Questions.Bad.{}.{}.parquet'.format(TAG_LIMIT, BAD_LIMIT), 26 | }, 27 | 'one_hot': { 28 | 'local': 'data/stackoverflow/Questions.Bad.OneHot.{}.{}.parquet'.format(TAG_LIMIT, BAD_LIMIT), 29 | 's3': 's3://stackoverflow-events/Questions.Bad.OneHot.{}.{}.parquet'.format(TAG_LIMIT, BAD_LIMIT), 30 | }, 31 | 'label_encoded': { 32 | 'local': 'data/stackoverflow/Questions.Bad.LabelEncoded.{}.{}.parquet'.format(TAG_LIMIT, BAD_LIMIT), 33 | 's3': 's3://stackoverflow-events/Questions.Bad.LabelEncoded.{}.{}.parquet'.format(TAG_LIMIT, BAD_LIMIT), 34 | }, 35 | 'weak_labels': 'data/stackoverflow/weak_labels.npy', 36 | } 37 | 38 | # Define a set of paths for each step for local and S3 39 | PATH_SET = 'local' # 's3' 40 | 41 | spark = SparkSession.builder\ 42 | .appName('Deep Products - Create Weak Labels')\ 43 | .config('spark.dynamicAllocation.enabled', True)\ 44 | .config('spark.shuffle.service.enabled', True)\ 45 | .getOrCreate() 46 | sc = spark.sparkContext 47 | 48 | bad_questions = spark.read.parquet( 49 | PATHS['bad_questions'][PATH_SET] 50 | ) 51 | 52 | 53 | # 54 | # Create indexes for each multilabel tag 55 | # 56 | tag_counts_df = spark.read.parquet(PATHS['bad_tag_counts'][PATH_SET]) 57 | enumerated_labels = [ 58 | z for z in enumerate( 59 | sorted( 60 | tag_counts_df.rdd 61 | .groupBy(lambda x: 1) 62 | .flatMap(lambda x: [y.tag for y in x[1]]) 63 | .collect() 64 | ) 65 | ) 66 | ] 67 | tag_index = {x: i for i, x in enumerated_labels} 68 | index_tag = {i: x for i, x in enumerated_labels} 69 | 70 | 71 | # 72 | # Use the indexes to label encode the data 73 | # 74 | def label_encode(x, tag_index): 75 | """Convert from a list of tags to a label encoded value""" 76 | for tag in x._Tags: 77 | yield Row( 78 | _Body=x._Body, 79 | _Code=x._Code, 80 | _Label=tag_index[tag] 81 | ) 82 | 83 | label_encoded = bad_questions.rdd.flatMap( 84 | lambda x: label_encode(x, tag_index) 85 | ) 86 | label_encoded_df = label_encoded.toDF() 87 | label_encoded_df.write.mode('overwrite').parquet(PATHS['label_encoded'][PATH_SET]) 88 | 89 | label_encoded_df = spark.read.parquet(PATHS['label_encoded'][PATH_SET]) 90 | 91 | 92 | # 93 | # Create Label Functions (LFs) for tag search 94 | # 95 | ABSTAIN = -1 96 | 97 | def keyword_lookup(x, keywords, label): 98 | match = any(word in x._Body for word in keywords) 99 | if match: 100 | return label 101 | return ABSTAIN 102 | 103 | def make_keyword_lf(keywords, label=ABSTAIN): 104 | return LabelingFunction( 105 | name=f"keyword_{keywords}", 106 | f=keyword_lookup, 107 | resources=dict(keywords=keywords, label=label), 108 | ) 109 | 110 | # A tag split by dashes '-' which aids the search (I think), ex. 
html-css 111 | keyword_lfs = OrderedDict() 112 | for i, tag in enumerated_labels: 113 | keyword_lfs[tag] = make_keyword_lf(tag.split('-'), label=i) 114 | 115 | # 116 | # Apply labeling functions to get a set of weak labels 117 | # 118 | spark_applier = SparkLFApplier(list(keyword_lfs.values())) 119 | weak_labels = spark_applier.apply(label_encoded) 120 | 121 | # Save the weak labels numpy array for use locallys 122 | np.save( 123 | PATHS['weak_labels'], 124 | weak_labels 125 | ) 126 | -------------------------------------------------------------------------------- /ch05/split_tags.spark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # 4 | # This script extracts the text and code of Stack Overflow questions (not answers) in separate fields along with one-hot 5 | # encoded labels (folksonomy tags, 1-5 each question) for records having at least so many occurrences. To run it locally 6 | # set PATH_SET to 'local'. For AWS using PATH_SET of 's3'. 7 | # 8 | # Run me with: PYSPARK_PYTHON=python3 PYSPARK_DRIVER_PYTHON=ipython3 pyspark 9 | # 10 | 11 | import gc 12 | import json 13 | import random 14 | import re 15 | 16 | import boto3 17 | from pyspark.sql import SparkSession, Row 18 | import pyspark.sql.functions as F 19 | import pyspark.sql.types as T 20 | 21 | from lib.utils import ( 22 | create_labeled_schema, create_label_row_columns, extract_text, extract_text_plain, 23 | extract_code_plain, get_indexes, one_hot_encode, 24 | ) 25 | 26 | 27 | # Set the minimum count a tag must occur to be included in our dataset 28 | TAG_LIMIT = 50 29 | 30 | # Set the maximum number of records to sample for each tags 31 | SAMPLE_LIMIT = 500 32 | 33 | # Print debug info as we compute, takes extra time 34 | DEBUG = False 35 | 36 | # Print a report on record/label duplication at the end 37 | REPORT = True 38 | 39 | # Define a set of paths for each step for local and S3 40 | PATH_SET = 'local' # 's3' 41 | 42 | PATHS = { 43 | 's3_bucket': 'stackoverflow-events', 44 | 'posts': { 45 | 'local': 'data/stackoverflow/Posts.df.parquet', 46 | 's3': 's3://stackoverflow-events/08-05-2019/Posts.df.parquet', 47 | }, 48 | 'questions': { 49 | 'local': 'data/stackoverflow/Questions.Answered.parquet', 50 | 's3': 's3://stackoverflow-events/08-05-2019/Questions.Answered.parquet', 51 | }, 52 | 'users_parquet': { 53 | 'local': 'data/stackoverflow/Users.df.parquet', 54 | 's3': 's3://stackoverflow-events/08-05-2019/Users.df.parquet', 55 | }, 56 | 'questions_users': { 57 | 'local': 'data/stackoverflow/QuestionsUsers.df.parquet', 58 | 's3': 's3://stackoverflow-events/08-05-2019/QuestionsUsers.df.parquet', 59 | }, 60 | 'tag_counts': { 61 | 'local': 'data/stackoverflow/Questions.TagCounts.All.parquet', 62 | 's3': 's3://stackoverflow-events/08-05-2019/Questions.TagCounts.All.parquet', 63 | }, 64 | 'questions_tags': { 65 | 'local': 'data/stackoverflow/Questions.Tags.{}.parquet', 66 | 's3': 's3://stackoverflow-events/08-05-2019/Questions.Tags.{}.parquet', 67 | }, 68 | 'per_tag': { 69 | 'local': 'data/stackoverflow/Questions.PerTag.{}.parquet', 70 | 's3': 's3://stackoverflow-events/08-05-2019/Questions.PerTag.{}.parquet', 71 | }, 72 | 'sample_ratios': { 73 | 'local': 'data/stackoverflow/Tag.SampleRatios.{}.parquet', 74 | 's3': 's3://stackoverflow-events/08-05-2019/Tag.SampleRatios.{}.parquet', 75 | }, 76 | 'sample': { 77 | 'local': 'data/stackoverflow/Questions.Stratified.All.{}.parquet', 78 | 's3': 's3://stackoverflow-events/08-05-2019/Questions.Stratified.All.{}.parquet', 
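        # Editor's note (illustrative): the {} placeholders in these path templates are filled in
        # with str.format() later in the script -- e.g. PATHS['questions_tags'][PATH_SET].format(TAG_LIMIT)
        # resolves to 'data/stackoverflow/Questions.Tags.50.parquet' when PATH_SET is 'local' and
        # TAG_LIMIT is 50.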
79 | }, 80 | 'tag_index': { 81 | 'local': 'data/stackoverflow/all_tag_index.{}.json', 82 | 's3': '08-05-2019/tag_index.{}.json', 83 | }, 84 | 'index_tag': { 85 | 'local': 'data/stackoverflow/all_index_tag.{}.json', 86 | 's3': '08-05-2019/index_tag.{}.json', 87 | }, 88 | } 89 | 90 | # 91 | # Initialize Spark with dynamic allocation enabled to (hopefully) use less RAM 92 | # 93 | spark = SparkSession.builder\ 94 | .appName('Weakly Supervised Learning - Extract Questions') \ 95 | .config('spark.dynamicAllocation.enabled', True) \ 96 | .config('spark.shuffle.service.enabled', True) \ 97 | .getOrCreate() 98 | sc = spark.sparkContext 99 | 100 | 101 | # 102 | # Get answered questions and not their answers 103 | # 104 | posts = spark.read.parquet(PATHS['posts'][PATH_SET]) 105 | if DEBUG is True: 106 | print('Total posts count: {:,}'.format( 107 | posts.count() 108 | )) 109 | questions = posts.filter(posts._ParentId.isNull())\ 110 | .filter(posts._AnswerCount > 0)\ 111 | .filter(posts._Score > 1) 112 | if DEBUG is True: 113 | print('Total questions count: {:,}'.format(questions.count())) 114 | 115 | # Combine title with body 116 | questions = questions.select( 117 | F.col('_Id').alias('_PostId'), 118 | '_AcceptedAnswerId', 119 | F.concat( 120 | F.col("_Title"), 121 | F.lit(" "), 122 | F.col("_Body") 123 | ).alias('_Body'), 124 | '_Tags', 125 | '_AnswerCount', 126 | '_CommentCount', 127 | '_FavoriteCount', 128 | '_OwnerUserId', 129 | '_OwnerDisplayName', 130 | '_Score', 131 | '_ViewCount', 132 | ) 133 | questions.show() 134 | 135 | # Write all questions to a Parquet file, then trim fields 136 | questions\ 137 | .write.mode('overwrite')\ 138 | .parquet(PATHS['questions'][PATH_SET]) 139 | questions_df = spark.read.parquet(PATHS['questions'][PATH_SET]) 140 | 141 | # 142 | # Join User records from ch02/xml_to_parquet.py 143 | # 144 | users_df = spark.read.parquet(PATHS['users_parquet'][PATH_SET]) 145 | users_df = users_df.withColumn( 146 | '_UserId', 147 | F.col('_Id') 148 | ).drop('_Id') 149 | 150 | questions_users_df = questions_df.join( 151 | users_df, 152 | on=questions_df._OwnerUserId == users_df._UserId, 153 | how='left_outer' 154 | ) 155 | questions_users_df = questions_users_df.selectExpr( 156 | '_PostId', 157 | '_AcceptedAnswerId', 158 | '_Body', 159 | '_Tags', 160 | '_AnswerCount', 161 | '_CommentCount', 162 | '_FavoriteCount', 163 | '_OwnerUserId', 164 | '_OwnerDisplayName', 165 | '_Score', 166 | '_ViewCount', 167 | '_AboutMe AS _UserAboutMe', 168 | '_AccountId', 169 | '_UserId', 170 | '_DisplayName AS _UserDisplayName', 171 | '_DownVotes AS _UserDownVotes', 172 | '_Location AS _UserLocation', 173 | '_ProfileImageUrl', 174 | '_Reputation AS _UserReputation', 175 | '_UpVotes AS _UserUpVotes', 176 | '_Views AS _UserViews', 177 | '_WebsiteUrl AS _UserWebsiteUrl', 178 | ) 179 | questions_users_df.write.mode('overwrite').parquet(PATHS['questions_users'][PATH_SET]) 180 | questions_users_df = spark.read.parquet(PATHS['questions_users'][PATH_SET]) 181 | if DEBUG is True: 182 | questions_users_df.show() 183 | 184 | # Count the number of each tag 185 | all_tags = questions_users_df.rdd.flatMap(lambda x: re.sub('[<>]', ' ', x['_Tags']).split()) 186 | 187 | # Count the instances of each tag 188 | tag_counts_df = all_tags\ 189 | .groupBy(lambda x: x)\ 190 | .map(lambda x: Row(tag=x[0], total=len(x[1])))\ 191 | .toDF()\ 192 | .select('tag', 'total').orderBy(['total'], ascending=False) 193 | tag_counts_df.write.mode('overwrite').parquet(PATHS['tag_counts'][PATH_SET]) 194 | tag_counts_df = 
spark.read.parquet(PATHS['tag_counts'][PATH_SET]) 195 | 196 | if DEBUG is True: 197 | tag_counts_df.show(100) 198 | 199 | # Create a local dict of tag counts 200 | local_tag_counts = tag_counts_df.rdd.collect() 201 | tag_counts = {x.tag: x.total for x in local_tag_counts} 202 | 203 | # Use tags with at least 50 instances 204 | TAG_LIMIT = 50 205 | 206 | # Count the good tags 207 | remaining_tags_df = tag_counts_df.filter(tag_counts_df.total > TAG_LIMIT) 208 | tag_total = remaining_tags_df.count() 209 | print(f'\n\nNumber of tags with > {TAG_LIMIT:,} instances: {tag_total:,}') 210 | valid_tags = remaining_tags_df.rdd.map(lambda x: x['tag']).collect() 211 | 212 | # Create forward and backward indexes for good/bad tags 213 | tag_index, index_tag, enumerated_labels = get_indexes(remaining_tags_df) 214 | 215 | # Turn text of body and tags into lists of words 216 | def tag_list_record(x, valid_tags): 217 | d = x.asDict() 218 | 219 | body = extract_text_plain(d['_Body']) 220 | code = extract_code_plain(x['_Body']) 221 | tags = re.sub('[<>]', ' ', x['_Tags']).split() 222 | valid_tags = [y for y in tags if y in valid_tags] 223 | 224 | 225 | d['_Body'] = body 226 | d['_Code'] = code 227 | d['_Tags'] = valid_tags 228 | d['_Label'] = 0 229 | 230 | return Row(**d) 231 | 232 | questions_lists = questions_users_df.rdd.map(lambda x: tag_list_record(x, valid_tags)) 233 | 234 | filtered_lists = questions_lists\ 235 | .filter(lambda x: bool(set(x._Tags) & set(valid_tags))) 236 | 237 | # Create a DataFrame to persist this progress 238 | tag_list_schema = T.StructType([ 239 | T.StructField('_PostId', T.IntegerType(), True), 240 | T.StructField('_AcceptedAnswerId', T.IntegerType(), True), 241 | T.StructField('_Body', T.StringType(), True), 242 | T.StructField('_Code', T.StringType(), True), 243 | T.StructField( 244 | "_Tags", 245 | T.ArrayType( 246 | T.StringType() 247 | ) 248 | ), 249 | T.StructField('_Label', T.IntegerType(), True), 250 | T.StructField('_AnswerCount', T.IntegerType(), True), 251 | T.StructField('_CommentCount', T.IntegerType(), True), 252 | T.StructField('_FavoriteCount', T.IntegerType(), True), 253 | T.StructField('_OwnerUserId', T.IntegerType(), True), 254 | T.StructField('_OwnerDisplayName', T.StringType(), True), 255 | T.StructField('_Score', T.IntegerType(), True), 256 | T.StructField('_ViewCount', T.IntegerType(), True), 257 | T.StructField('_UserAboutMe', T.StringType(), True), 258 | T.StructField('_AccountId',T.IntegerType(), True), 259 | T.StructField('_UserId', T.IntegerType(), True), 260 | T.StructField('_UserDisplayName', T.StringType(),True), 261 | T.StructField('_UserDownVotes', T.IntegerType(), True), 262 | T.StructField('_UserLocation', T.StringType(), True), 263 | T.StructField('_ProfileImageUrl', T.StringType(), True), 264 | T.StructField('_UserReputation', T.IntegerType() ,True), 265 | T.StructField('_UserUpVotes', T.IntegerType(), True), 266 | T.StructField('_UserViews', T.IntegerType(), True), 267 | T.StructField('_UserWebsiteUrl', T.StringType(), True), 268 | ]) 269 | 270 | questions_tags_df = spark.createDataFrame( 271 | filtered_lists, 272 | tag_list_schema 273 | ) 274 | 275 | questions_tags_df.write.mode('overwrite').parquet(PATHS['questions_tags'][PATH_SET].format(TAG_LIMIT)) 276 | questions_tags_df = spark.read.parquet(PATHS['questions_tags'][PATH_SET].format(TAG_LIMIT)) 277 | questions_tags_df.show() 278 | 279 | # # Emit one record per tag 280 | # def emit_tag_records(x, tag_index): 281 | # d = x.asDict() 282 | 283 | # for tag in d['_Tags']: 284 | 285 | # n = 
d.copy() 286 | # n['_LabelIndex'] = tag_index[tag] 287 | # n['_LabelString'] = tag 288 | # n['_LabelValue'] = 1 289 | # del n['_Tags'] 290 | 291 | # yield(Row(**n)) 292 | 293 | # per_tag_questions = questions_tags_df.rdd.flatMap(lambda x: emit_tag_records(x, tag_index)) 294 | 295 | # # Create a DataFrame out of the one-hot encoded RDD 296 | # per_tag_schema = T.StructType([ 297 | # T.StructField('_PostId', T.IntegerType(), True), 298 | # T.StructField('_AcceptedAnswerId', T.IntegerType(), True), 299 | # T.StructField('_Body', T.StringType(), True), 300 | # T.StructField('_Code', T.StringType(), True), 301 | # T.StructField('_LabelIndex', T.IntegerType(), True), 302 | # T.StructField('_LabelString', T.StringType(), True), 303 | # T.StructField('_LabelValue', T.IntegerType(), True), 304 | # T.StructField('_AnswerCount', T.IntegerType(), True), 305 | # T.StructField('_CommentCount', T.IntegerType(), True), 306 | # T.StructField('_FavoriteCount', T.IntegerType(), True), 307 | # T.StructField('_OwnerUserId', T.IntegerType(), True), 308 | # T.StructField('_OwnerDisplayName', T.StringType(), True), 309 | # T.StructField('_Score', T.IntegerType(), True), 310 | # T.StructField('_ViewCount', T.IntegerType(), True), 311 | # T.StructField('_UserAboutMe', T.StringType(), True), 312 | # T.StructField('_AccountId',T.IntegerType(), True), 313 | # T.StructField('_UserId', T.IntegerType(), True), 314 | # T.StructField('_UserDisplayName', T.StringType(),True), 315 | # T.StructField('_UserDownVotes', T.IntegerType(), True), 316 | # T.StructField('_UserLocation', T.StringType(), True), 317 | # T.StructField('_ProfileImageUrl', T.StringType(), True), 318 | # T.StructField('_UserReputation', T.IntegerType() ,True), 319 | # T.StructField('_UserUpVotes', T.IntegerType(), True), 320 | # T.StructField('_UserViews', T.IntegerType(), True), 321 | # T.StructField('_UserWebsiteUrl', T.StringType(), True), 322 | # ]) 323 | 324 | # per_tag_df = spark.createDataFrame( 325 | # per_tag_questions, 326 | # per_tag_schema 327 | # ) 328 | 329 | # # Save as Parquet format, partitioned by the label index 330 | # per_tag_df.write.mode('overwrite').parquet( 331 | # PATHS['per_tag'][PATH_SET].format(TAG_LIMIT), 332 | # partitionBy=['_LabelIndex'] 333 | # ) 334 | 335 | # per_tag_df = spark.read.parquet( 336 | # PATHS['per_tag'][PATH_SET].format(TAG_LIMIT) 337 | # ) 338 | # per_tag_df.registerTempTable('per_tag') 339 | 340 | # # # 341 | # # # 1) Use GROUP BY to get sample ratios 342 | # # # 343 | # # from datetime import datetime 344 | 345 | # # # Get the counts for tags all at once 346 | # # total_records_df = spark.sql('SELECT COUNT(*) AS total FROM per_tag') 347 | # # total_records = total_records_df.first().total 348 | 349 | # # query = f""" 350 | # # SELECT 351 | # # _LabelIndex, 352 | # # COUNT(*) AS total, 353 | # # COUNT(*)/{total_records} AS sample_ratio 354 | # # FROM per_tag 355 | # # GROUP BY _LabelIndex 356 | # # """ 357 | # # sample_ratios_df = spark.sql(query) 358 | # # sample_ratios_df.write.mode('overwrite').parquet( 359 | # # PATHS['sample_ratios'][PATH_SET].format(TAG_LIMIT) 360 | # # ) 361 | # # sample_ratios = sample_ratios_df.rdd.map(lambda x: x.asDict()).collect() 362 | # # sample_ratios_d = {x['_LabelIndex'] : x for x in sample_ratios} 363 | 364 | # start = datetime.now() 365 | 366 | # def sample_group(x, sample_ratios_d): 367 | # sample_n = 50 368 | # rs = random.Random() 369 | # yield rs.sample(list(x), sample_n) 370 | 371 | 372 | # groupable = stratified_down_sample = per_tag_df.rdd \ 373 | # .map(lambda x: 
(x._LabelIndex, x)) 374 | 375 | # grouped = groupable.groupByKey() 376 | 377 | # .flatMap( 378 | # lambda x: sample_group(x[1] if len(x) > 0 else [], sample_ratios_d) 379 | # ) \ 380 | # .flatMapValues(lambda x: x[1]) 381 | 382 | # stratified_df = spark.createDataFrame( 383 | # stratified_down_sample, 384 | # per_tag_schema 385 | # ) 386 | # stratified_df.write.mode('overwrite').parquet( 387 | # PATHS['sample'][PATH_SET].format(TAG_LIMIT), 388 | # partitionBy=['_LabelIndex'] 389 | # ) 390 | 391 | # end = datetime.now() 392 | # speed = end - start 393 | # print(speed) 394 | 395 | # # diff = speed_3 - speed_2 396 | # # print(diff) 397 | 398 | 399 | # # # Write out a stratify_limit sized stratified sample for each tag 400 | # # for i in range(0, 10):#tag_total): 401 | # # print(f'\n\nProcessing tag {i:,} of {tag_total:,} total tags\n\n') 402 | 403 | # # one_label_df = one_hot_df.rdd.map( 404 | # # lambda x: Row( 405 | # # _Body=x._Body, 406 | # # _Code=x._Code, 407 | # # _Label=x._Tags[i] 408 | # # ) 409 | # # ).toDF() 410 | 411 | # # one_label_df = one_hot_df.select( 412 | # # '_Body', 413 | # # '_Code', 414 | # # one_hot_df['_Tags'].getItem(i).alias('_Label') 415 | # # ) 416 | 417 | # # # Select records with a positive value for this tag 418 | # # positive_examples = one_label_df.filter(one_label_df._Label == 1) 419 | # # negative_examples = one_label_df.filter(one_label_df._Label == 0) 420 | 421 | # # # Sample the positive examples to equal the stratify limit 422 | # # positive_count = positive_examples.count() 423 | # # ratio = min(1.0, SAMPLE_LIMIT / positive_count) 424 | # # sample_ratio = max(0.0, ratio) 425 | # # positive_examples_sample = positive_examples.sample(False, sample_ratio, seed=1337) 426 | 427 | # # # Now get an equal number of negative examples 428 | # # positive_count = positive_examples_sample.count() 429 | # # negative_count = negative_examples.count() 430 | # # ratio = min(1.0, positive_count / negative_count) 431 | # # sample_ratio = max(0.0, ratio) 432 | # # negative_examples_sample = negative_examples.sample(False, sample_ratio, seed=1337) 433 | 434 | # # final_examples_df = positive_examples_sample.union(negative_examples_sample) 435 | 436 | # # if DEBUG is True: 437 | # # final_examples_df.show() 438 | 439 | # # # Write the record out as JSON under a directory we will then read in its enrirety 440 | # # final_examples_df.write.mode('overwrite').json(PATHS['output_jsonl'][PATH_SET].format(TAG_LIMIT, i)) 441 | 442 | # # # Free RAM explicitly each loop 443 | # # del final_examples_df 444 | # # gc.collect() 445 | 446 | -------------------------------------------------------------------------------- /conda.env.yaml: -------------------------------------------------------------------------------- 1 | name: weak 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _tflow_select=2.1.0=gpu 8 | - absl-py=0.8.1=py37_0 9 | - appdirs=1.4.3=py37h28b3542_0 10 | - arrow-cpp=0.14.1=py37h6b969ab_1 11 | - asn1crypto=1.2.0=py37_0 12 | - astor=0.8.0=py37_0 13 | - attrs=19.3.0=py_0 14 | - backcall=0.1.0=py37_0 15 | - beautifulsoup4=4.8.1=py37_0 16 | - black=19.3b0=py_0 17 | - blas=1.0=mkl 18 | - bleach=3.1.0=py37_0 19 | - blessings=1.7=py37_1000 20 | - bokeh=1.4.0=py37_0 21 | - boost-cpp=1.70.0=ha2d47e9_1 22 | - boto=2.49.0=py37_0 23 | - boto3=1.9.234=py_0 24 | - botocore=1.12.234=py_0 25 | - brotli=1.0.7=he6710b0_0 26 | - bz2file=0.98=py37_1 27 | - bzip2=1.0.8=h7b6447c_0 28 | - c-ares=1.15.0=h7b6447c_1001 29 | - 
ca-certificates=2019.10.16=0 30 | - certifi=2019.9.11=py37_0 31 | - cffi=1.13.2=py37h2e261b9_0 32 | - chardet=3.0.4=py37_1003 33 | - click=7.0=py37_0 34 | - cloudpickle=1.2.2=py_0 35 | - cryptography=2.8=py37h1ba5d50_0 36 | - cudatoolkit=10.0.130=0 37 | - cudnn=7.6.4=cuda10.0_0 38 | - cupti=10.0.130=0 39 | - cycler=0.10.0=py37_0 40 | - cymem=2.0.2=py37hfd86e86_0 41 | - cython-blis=0.4.1=py37h516909a_0 42 | - cytoolz=0.10.1=py37h7b6447c_0 43 | - dask=2.8.0=py_1 44 | - dask-core=2.8.0=py_0 45 | - dask-glm=0.2.0=py37_0 46 | - dask-ml=1.1.1=py_0 47 | - dbus=1.13.12=h746ee38_0 48 | - decorator=4.4.1=py_0 49 | - defusedxml=0.6.0=py_0 50 | - dill=0.3.1.1=py37_0 51 | - distributed=2.8.0=py_1 52 | - docutils=0.15.2=py37_0 53 | - double-conversion=3.1.5=he6710b0_1 54 | - entrypoints=0.3=py37_0 55 | - expat=2.2.6=he6710b0_0 56 | - filelock=3.0.12=py_0 57 | - fontconfig=2.13.0=h9420a91_0 58 | - freetype=2.9.1=h8a8886c_1 59 | - frozendict=1.2=py_2 60 | - fsspec=0.6.0=py_0 61 | - gast=0.2.2=py37_0 62 | - gensim=3.8.0=py37h962f231_0 63 | - gflags=2.2.2=he6710b0_0 64 | - glib=2.63.1=h5a9c865_0 65 | - glog=0.4.0=he6710b0_0 66 | - gmp=6.1.2=h6c8ec71_1 67 | - google-pasta=0.1.8=py_0 68 | - gpustat=0.6.0=py_0 69 | - grpc-cpp=1.25.0=h18db393_0 70 | - grpcio=1.16.1=py37hf8bcb03_1 71 | - gst-plugins-base=1.14.0=hbbd80ab_1 72 | - gstreamer=1.14.0=hb453b48_1 73 | - h5py=2.9.0=py37h7918eee_0 74 | - hdf5=1.10.4=hb1b8bf9_0 75 | - heapdict=1.0.1=py_0 76 | - icu=58.2=h9c2bf20_1 77 | - idna=2.8=py37_0 78 | - importlib_metadata=0.23=py37_0 79 | - intel-openmp=2019.4=243 80 | - ipykernel=5.1.3=py37h39e3cac_0 81 | - ipython=7.9.0=py37h39e3cac_0 82 | - ipython_genutils=0.2.0=py37_0 83 | - ipywidgets=7.5.1=py_0 84 | - iso8601=0.1.12=py37_1 85 | - jedi=0.15.1=py37_0 86 | - jinja2=2.10.3=py_0 87 | - jmespath=0.9.4=py_0 88 | - joblib=0.14.0=py_0 89 | - jpeg=9b=h024ee3a_2 90 | - jsonschema=3.2.0=py37_0 91 | - jupyter=1.0.0=py37_7 92 | - jupyter_client=5.3.4=py37_0 93 | - jupyter_console=6.0.0=py37_0 94 | - jupyter_core=4.6.1=py37_0 95 | - keras-applications=1.0.8=py_0 96 | - keras-preprocessing=1.1.0=py_1 97 | - kiwisolver=1.1.0=py37he6710b0_0 98 | - libedit=3.1.20181209=hc058e9b_0 99 | - libevent=2.1.10=h72c5cf5_0 100 | - libffi=3.2.1=hd88cf55_4 101 | - libgcc-ng=9.1.0=hdf63c60_0 102 | - libgfortran-ng=7.3.0=hdf63c60_0 103 | - libpng=1.6.37=hbc83047_0 104 | - libprotobuf=3.8.0=hd408876_0 105 | - libsodium=1.0.16=h1bed415_0 106 | - libstdcxx-ng=9.1.0=hdf63c60_0 107 | - libtiff=4.0.9=he85c1e1_2 108 | - libuuid=1.0.3=h1bed415_2 109 | - libxcb=1.13=h1bed415_1 110 | - libxml2=2.9.9=hea5a465_1 111 | - libxslt=1.1.33=h7d1a2b0_0 112 | - llvmlite=0.30.0=py37hd408876_0 113 | - locket=0.2.0=py37_1 114 | - lxml=4.4.1=py37hefd8a0e_0 115 | - lz4-c=1.8.3=he1b5a44_1001 116 | - markdown=3.1.1=py37_0 117 | - markupsafe=1.1.1=py37h7b6447c_0 118 | - matplotlib=3.1.1=py37h5429711_0 119 | - mistune=0.8.4=py37h7b6447c_0 120 | - mkl=2019.4=243 121 | - mkl-service=2.3.0=py37he904b0f_0 122 | - mkl_fft=1.0.15=py37ha843d7b_0 123 | - mkl_random=1.1.0=py37hd6b4f25_0 124 | - more-itertools=7.2.0=py37_0 125 | - msgpack-python=0.6.1=py37hfd86e86_1 126 | - multipledispatch=0.6.0=py37_0 127 | - murmurhash=1.0.2=py37he6710b0_0 128 | - nbconvert=5.6.1=py37_0 129 | - nbformat=4.4.0=py37_0 130 | - nbstripout=0.3.6=py_0 131 | - ncurses=6.1=he6710b0_1 132 | - nltk=3.4.5=py37_0 133 | - notebook=6.0.2=py37_0 134 | - numba=0.46.0=py37h962f231_0 135 | - numpy=1.17.3=py37hd14ec0e_0 136 | - numpy-base=1.17.3=py37hde5b4d6_0 137 | - nvidia-ml=7.352.0=py_0 138 | - 
olefile=0.46=py37_0 139 | - openssl=1.1.1d=h7b6447c_3 140 | - opt_einsum=3.1.0=py_0 141 | - packaging=19.2=py_0 142 | - pandas=0.25.3=py37he6710b0_0 143 | - pandoc=2.2.3.2=0 144 | - pandocfilters=1.4.2=py37_1 145 | - parquet-cpp=1.5.1=2 146 | - parso=0.5.1=py_0 147 | - partd=1.0.0=py_0 148 | - patsy=0.5.1=py37_0 149 | - pcre=8.43=he6710b0_0 150 | - pexpect=4.7.0=py37_0 151 | - pickleshare=0.7.5=py37_0 152 | - pillow=5.4.1=py37h34e0f95_0 153 | - pip=19.3.1=py37_0 154 | - plac=0.9.6=py37_0 155 | - pluggy=0.13.0=py37_0 156 | - preshed=3.0.2=py37he1b5a44_1 157 | - prometheus_client=0.7.1=py_0 158 | - prompt_toolkit=2.0.10=py_0 159 | - protobuf=3.8.0=py37he6710b0_0 160 | - psutil=5.6.5=py37h7b6447c_0 161 | - ptyprocess=0.6.0=py37_0 162 | - py=1.8.0=py37_0 163 | - py4j=0.10.7=py37_0 164 | - pyarrow=0.14.1=py37h8b68381_2 165 | - pycparser=2.19=py37_0 166 | - pygments=2.4.2=py_0 167 | - pyopenssl=19.1.0=py37_0 168 | - pyparsing=2.4.5=py_0 169 | - pyqt=5.9.2=py37h05f1152_2 170 | - pyrsistent=0.15.5=py37h7b6447c_0 171 | - pysocks=1.7.1=py37_0 172 | - pyspark=2.4.4=py_0 173 | - python=3.7.4=h265db76_1 174 | - python-dateutil=2.8.1=py_0 175 | - pytz=2019.3=py_0 176 | - pyyaml=5.1.2=py37h7b6447c_0 177 | - pyzmq=18.1.0=py37he6710b0_0 178 | - qt=5.9.7=h5867ecd_1 179 | - qtconsole=4.6.0=py_0 180 | - re2=2019.08.01=he6710b0_0 181 | - readline=7.0=h7b6447c_5 182 | - requests=2.22.0=py37_0 183 | - s3fs=0.4.0=py_0 184 | - s3transfer=0.2.1=py37_0 185 | - scikit-learn=0.21.3=py37hd81dba3_0 186 | - scipy=1.3.1=py37h7c811a0_0 187 | - seaborn=0.9.0=pyh91ea838_1 188 | - send2trash=1.5.0=py37_0 189 | - setuptools=41.6.0=py37_0 190 | - sip=4.19.8=py37hf484d3e_0 191 | - six=1.13.0=py37_0 192 | - smart_open=1.8.4=py_0 193 | - snappy=1.1.7=hbae5bb6_3 194 | - sortedcontainers=2.1.0=py37_0 195 | - soupsieve=1.9.5=py37_0 196 | - spacy=2.2.2=py37hc9558a2_0 197 | - sqlite=3.30.1=h7b6447c_0 198 | - srsly=0.2.0=py37he1b5a44_0 199 | - statsmodels=0.10.1=py37hdd07704_0 200 | - tbb=2019.8=hfd86e86_0 201 | - tblib=1.5.0=py_0 202 | - tensorboard=2.0.0=pyhb230dea_0 203 | - tensorflow=2.0.0=gpu_py37h768510d_0 204 | - tensorflow-base=2.0.0=gpu_py37h0ec5d1f_0 205 | - tensorflow-estimator=2.0.0=pyh2649769_0 206 | - tensorflow-gpu=2.0.0=h0d30ee6_0 207 | - tensorflow-hub=0.7.0=pyhe6710b0_0 208 | - termcolor=1.1.0=py37_1 209 | - terminado=0.8.3=py37_0 210 | - testpath=0.4.4=py_0 211 | - textblob=0.15.3=py_0 212 | - texttable=1.6.2=py_0 213 | - thinc=7.3.0=py37hc9558a2_0 214 | - thrift-cpp=0.12.0=hf3afdfd_1004 215 | - tk=8.6.8=hbc83047_0 216 | - toml=0.10.0=py37h28b3542_0 217 | - toolz=0.10.0=py_0 218 | - tornado=6.0.3=py37h7b6447c_0 219 | - tox=3.14.1=py_1 220 | - tqdm=4.38.0=py_0 221 | - traitlets=4.3.3=py37_0 222 | - uriparser=0.9.3=he1b5a44_1 223 | - urllib3=1.24.2=py37_0 224 | - virtualenv=16.7.5=py_0 225 | - wasabi=0.4.0=py_0 226 | - wcwidth=0.1.7=py37_0 227 | - webencodings=0.5.1=py37_1 228 | - werkzeug=0.16.0=py_0 229 | - wheel=0.33.6=py37_0 230 | - widgetsnbextension=3.5.1=py37_0 231 | - wrapt=1.11.2=py37h7b6447c_0 232 | - xz=5.2.4=h14c3975_4 233 | - yaml=0.1.7=had09818_2 234 | - zeromq=4.3.1=he6710b0_3 235 | - zict=1.0.0=py_0 236 | - zipp=0.6.0=py_0 237 | - zlib=1.2.11=h7b6447c_3 238 | - zstd=1.4.0=h3b9ef0a_0 239 | - pip: 240 | - argh==0.26.2 241 | - catalogue==0.0.8 242 | - configparser==4.0.2 243 | - cupy-cuda100==6.5.0 244 | - docker-pycreds==0.4.0 245 | - en-core-web-lg==2.2.5 246 | - fastrlock==0.4 247 | - gitdb2==2.0.6 248 | - gitpython==3.0.5 249 | - gql==0.1.0 250 | - graphql-core==2.2.1 251 | - munkres==1.1.2 252 | - 
networkx==2.3 253 | - pathtools==0.1.2 254 | - pip-tools==4.2.0 255 | - promise==2.2.1 256 | - rx==1.6.1 257 | - sentry-sdk==0.13.2 258 | - shortuuid==0.5.0 259 | - smmap2==2.0.5 260 | - subprocess32==3.5.4 261 | - tensorboardx==1.9 262 | - torch==1.3.1 263 | - wandb==0.8.15 264 | - watchdog==0.9.0 265 | -------------------------------------------------------------------------------- /conda.pip.requirements.txt: -------------------------------------------------------------------------------- 1 | cupy-cuda100 2 | pip-tools 3 | -e git+git://github.com/snorkel-team/snorkel@master#egg=snorkel 4 | spacy[cuda100] @ git+git://github.com/mmaybeno/spaCy@agnostic-vocab-array-fix#egg=snorkel[cuda100] 5 | torch 6 | wandb 7 | -------------------------------------------------------------------------------- /conda.requirements.txt: -------------------------------------------------------------------------------- 1 | beautifulsoup4 2 | boto3 3 | cython 4 | # bert-for-tf2 5 | dask 6 | dill 7 | # fast-bert 8 | frozendict 9 | gensim 10 | gpustat 11 | # guesslang 12 | ipython 13 | iso8601 14 | jupyter 15 | lxml 16 | nltk 17 | numpy>=1.16.0 18 | pandas 19 | pyarrow==0.14.1 20 | pyspark 21 | requests 22 | s3fs 23 | seaborn 24 | scikit-learn 25 | spacy[cuda100]==2.2.2 26 | tensorflow-gpu==2.1.0 27 | tensorflow-hub 28 | textblob 29 | textdistance 30 | texttable 31 | tqdm 32 | 33 | -------------------------------------------------------------------------------- /data/.exists: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rjurney/weakly_supervised_learning_code/1b1c2b1336384b0d4ca47846d52b459d61468197/data/.exists -------------------------------------------------------------------------------- /data/amazon_github_repos.json.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rjurney/weakly_supervised_learning_code/1b1c2b1336384b0d4ca47846d52b459d61468197/data/amazon_github_repos.json.bz2 -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.7" 2 | services: 3 | jupyter: 4 | build: . 
5 | image: weakly_supervised_learning 6 | container_name: weakly_supervised_learning 7 | labels: 8 | description: Weakly Supervised Learning (O'Reilly, 2020) by Russell Jurney 9 | name: weakly_supervised_learning 10 | volumes: 11 | - ".:/weakly_supervised_learning" 12 | ports: 13 | - "8888:8888" 14 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # I download the datasets for this book 5 | # 6 | 7 | # Get the Stack Overflow posts 8 | nohup wget https://archive.org/download/stackexchange/stackoverflow.com-Posts.7z & 9 | 10 | # Get the Stack Overflow users 11 | nohup wget https://archive.org/download/stackexchange/stackoverflow.com-Users.7z & 12 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | # Utilities for the book's notebooks 2 | 3 | import re 4 | 5 | import nltk 6 | import numpy as np 7 | import pyspark.sql.functions as F 8 | import pyspark.sql.types as T 9 | 10 | from bs4 import BeautifulSoup 11 | from nltk.corpus import stopwords 12 | from nltk.tokenize import RegexpTokenizer 13 | from nltk.tokenize.punkt import PunktSentenceTokenizer 14 | from pyspark.sql import Row 15 | from snorkel.analysis import get_label_buckets 16 | 17 | 18 | # In order to tokenize questions and remove stopwords 19 | nltk.download('punkt') 20 | nltk.download('stopwords') 21 | stop_words = set(stopwords.words('english')) 22 | tokenizer = RegexpTokenizer(r'\w+') 23 | 24 | 25 | def fix_metric_name(name): 26 | """Remove the trailing _NN, ex. precision_86""" 27 | if name[-1].isdigit(): 28 | repeat_name = '_'.join(name.split('_')[:-1]) 29 | else: 30 | repeat_name = name 31 | return repeat_name 32 | 33 | 34 | def fix_value(val): 35 | """Convert from numpy to float""" 36 | return val.item() if isinstance(val, np.float32) else val 37 | 38 | 39 | def fix_metric(name, val): 40 | """Fix Tensorflow/Keras metrics by removing any training _NN and concert numpy.float to python float""" 41 | repeat_name = fix_metric_name(name) 42 | py_val = fix_value(val) 43 | return repeat_name, py_val 44 | 45 | 46 | def get_indexes(df): 47 | """Create indexes for each multilabel tag""" 48 | enumerated_labels = [ 49 | z for z in enumerate( 50 | sorted( 51 | df.rdd 52 | .groupBy(lambda x: 1) 53 | .flatMap(lambda x: [y.tag for y in x[1]]) 54 | .collect() 55 | ) 56 | ) 57 | ] 58 | tag_index = {x: i for i, x in enumerated_labels} 59 | index_tag = {i: x for i, x in enumerated_labels} 60 | return tag_index, index_tag, enumerated_labels 61 | 62 | 63 | def extract_text(x, max_len=200, pad_token='__PAD__', stop_words=stop_words): 64 | """Extract, remove stopwords and tokenize non-code text from posts (questions/answers)""" 65 | doc = BeautifulSoup(x, 'lxml') 66 | codes = doc.find_all('code') 67 | [code.extract() if code else None for code in codes] 68 | text = re.sub(r'http\S+', ' ', doc.text) 69 | tokens = [x for x in tokenizer.tokenize(text) if x not in stop_words] 70 | 71 | padded_tokens = [] 72 | if pad_token: 73 | padded_tokens = [tokens[i] if len(tokens) > i else pad_token for i in range(0, max_len)] 74 | else: 75 | padded_tokens = tokens 76 | return padded_tokens 77 | 78 | 79 | def extract_text_plain(x): 80 | """Extract non-code text from posts (questions/answers)""" 81 | doc = BeautifulSoup(x, 'lxml') 82 | codes = doc.find_all('code') 83 | [code.extract() 
if code else None for code in codes] 84 | text = re.sub(r'http\S+', ' ', doc.text) 85 | return text 86 | 87 | 88 | def extract_code_plain(x): 89 | """Extract code text from posts (questions/answers)""" 90 | doc = BeautifulSoup(x, 'lxml') 91 | codes = doc.find_all('code') 92 | text = '\n'.join([c.text for c in codes]) 93 | return text 94 | 95 | 96 | def extract_bert_format(x): 97 | """Extract text in BERT format""" 98 | 99 | # Parse the sentences from the document 100 | sentence_tokenizer = PunktSentenceTokenizer() 101 | sentences = sentence_tokenizer.tokenize(x) 102 | 103 | # Write each sentence exactly as it appared to one line each 104 | for sentence in sentences: 105 | yield(sentence.encode('unicode-escape').decode().replace('\\\\', '\\')) 106 | 107 | # Add the final document separator 108 | yield('') 109 | 110 | 111 | def one_hot_encode(tag_list, enumerated_labels, index_tag): 112 | """PySpark can't one-hot-encode multilabel data, so we do it ourselves.""" 113 | 114 | one_hot_row = [] 115 | for i, label in enumerated_labels: 116 | if index_tag[i] in tag_list: 117 | one_hot_row.append(1) 118 | else: 119 | one_hot_row.append(0) 120 | assert(len(one_hot_row) == len(enumerated_labels)) 121 | return one_hot_row 122 | 123 | 124 | def create_labeled_schema(one_row): 125 | """Create a schema naming all one-hot encoded fields label_{}""" 126 | schema_list = [ 127 | T.StructField("_Body", T.StringType()), 128 | T.StructField("_Code", T.StringType()), 129 | ] 130 | for i, val in list(enumerate(one_row._Tags)): 131 | schema_list.append( 132 | T.StructField( 133 | f'label_{i}', 134 | T.IntegerType() 135 | ) 136 | ) 137 | return T.StructType(schema_list) 138 | 139 | 140 | def create_label_row_columns(x): 141 | """Create a dict keyed with dynamic args to use to create a Row for this record""" 142 | args = {f'label_{i}': val for i, val in list(enumerate(x._Tags))} 143 | args['_Body'] = x._Body 144 | args['_Code'] = x._Code 145 | return Row(**args) 146 | 147 | 148 | def get_mistakes(df, probs_test, buckets, labels, label_names): 149 | """Take DataFrame and pair of actual/predicted labels/names and return a DataFrame showing those records.""" 150 | df_fn = df.iloc[buckets[labels]] 151 | df_fn['probability'] = probs_test[buckets[labels], 1] 152 | df_fn['true label'] = label_names[0] 153 | df_fn['predicted label'] = label_names[1] 154 | return df_fn 155 | 156 | 157 | def mistakes_df(df, label_model, L_test, y_test): 158 | """Compute a DataFrame of all the mistakes we've seen.""" 159 | out_dfs = [] 160 | 161 | probs_test = label_model.predict_proba(L=L_test) 162 | preds_test = probs_test >= 0.5 163 | 164 | buckets = get_label_buckets( 165 | y_test, 166 | L_test[:, 1] 167 | ) 168 | print(buckets) 169 | 170 | for (actual, predicted) in buckets.keys(): 171 | 172 | # Only shot mistakes that we actually voted on 173 | if actual != predicted: 174 | 175 | actual_name = number_to_name_dict[actual] 176 | predicted_name = number_to_name_dict[predicted] 177 | 178 | out_dfs.append( 179 | get_mistakes( 180 | df, 181 | probs_test, 182 | buckets=buckets, 183 | labels=(actual, predicted), 184 | label_names=(actual_name, predicted_name) 185 | ) 186 | ) 187 | 188 | if len(out_dfs) > 1: 189 | return out_dfs[0].append( 190 | out_dfs[1:] 191 | ) 192 | else: 193 | return out_dfs[0] 194 | -------------------------------------------------------------------------------- /paths.json: -------------------------------------------------------------------------------- 1 | { 2 | "s3_bucket": "stackoverflow-events", 3 | "posts_xml": { 4 | 
"local": "data/stackoverflow/Posts.xml.bz2", 5 | "s3": "s3://stackoverflow-events/2019-12-02/Posts.xml.bz2" 6 | }, 7 | "posts": { 8 | "local": "data/stackoverflow/Posts.df.parquet", 9 | "s3": "s3://stackoverflow-events/2019-12-02/Posts.df.parquet" 10 | }, 11 | "users_xml": { 12 | "local": "data/stackoverflow/Users.xml.bz2", 13 | "s3": "s3://stackoverflow-events/2019-12-02/Users.xml.bz2" 14 | }, 15 | "users": { 16 | "local": "data/stackoverflow/Users.df.parquet", 17 | "s3": "s3://stackoverflow-events/2019-12-02/Users.df.parquet" 18 | }, 19 | "tags_xml": { 20 | "local": "data/stackoverflow/Tags.xml.bz2", 21 | "s3": "s3://stackoverflow-events/2019-12-02/Tags.xml.bz2" 22 | }, 23 | "tags": { 24 | "local": "data/stackoverflow/Tags.df.parquet", 25 | "s3": "s3://stackoverflow-events/2019-12-02/Tags.df.parquet" 26 | }, 27 | "badges_xml": { 28 | "local": "data/stackoverflow/Badges.xml.bz2", 29 | "s3": "s3://stackoverflow-events/2019-12-02/Badges.xml.bz2" 30 | }, 31 | "badges": { 32 | "local": "data/stackoverflow/Badges.df.parquet", 33 | "s3": "s3://stackoverflow-events/2019-12-02/Badges.df.parquet" 34 | }, 35 | "comments_xml": { 36 | "local": "data/stackoverflow/Comments.xml.bz2", 37 | "s3": "s3://stackoverflow-events/2019-12-02/Comments.xml.bz2" 38 | }, 39 | "comments": { 40 | "local": "data/stackoverflow/Comments.df.parquet", 41 | "s3": "s3://stackoverflow-events/2019-12-02/Comments.df.parquet" 42 | }, 43 | "postlinks_xml": { 44 | "local": "data/stackoverflow/PostLinks.xml.bz2", 45 | "s3": "s3://stackoverflow-events/2019-12-02/PostLinks.xml.bz2" 46 | }, 47 | "postlinks": { 48 | "local": "data/stackoverflow/PostLinks.df.parquet", 49 | "s3": "s3://stackoverflow-events/2019-12-02/PostLinks.df.parquet" 50 | }, 51 | "questions": { 52 | "local": "data/stackoverflow/Questions.Answered.parquet", 53 | "s3": "s3://stackoverflow-events/2019-12-02/Questions.Answered.parquet" 54 | }, 55 | "tag_counts": { 56 | "local": "data/stackoverflow/Questions.TagCounts.{}.parquet", 57 | "s3": "s3://stackoverflow-events/2019-12-02/Questions.TagCounts.{}.parquet" 58 | }, 59 | "questions_tags": { 60 | "local": "data/stackoverflow/Questions.Tags.{}.parquet", 61 | "s3": "s3://stackoverflow-events/2019-12-02/Questions.Tags.{}.parquet" 62 | }, 63 | "one_hot": { 64 | "local": "data/stackoverflow/Questions.Stratified.{}.parquet", 65 | "s3": "s3://stackoverflow-events/2019-12-02/Questions.Stratified.{}.parquet" 66 | }, 67 | "output_jsonl": { 68 | "local": "data/stackoverflow/Questions.Stratified.{}.{}.jsonl", 69 | "s3": "s3://stackoverflow-events/2019-12-02/Questions.Stratified.{}.{}.jsonl" 70 | }, 71 | "tag_index": { 72 | "local": "data/stackoverflow/tag_index.{}.json", 73 | "s3": "2019-12-02/tag_index.{}.json" 74 | }, 75 | "index_tag": { 76 | "local": "data/stackoverflow/index_tag.{}.json", 77 | "s3": "2019-12-02/index_tag.{}.json" 78 | }, 79 | "sorted_all_tags": { 80 | "local": "data/stackoverflow/sorted_all_tags.{}.json", 81 | "s3": "2019-12-02/sorted_all_tags.{}.json" 82 | }, 83 | "stratified_sample": { 84 | "local": "data/stackoverflow/Questions.Stratified.{}.*.jsonl", 85 | "s3": "s3://stackoverflow-events/2019-12-02/Questions.Stratified.{}.*.jsonl" 86 | }, 87 | "label_counts": { 88 | "local": "data/stackoverflow/label_counts.{}.json", 89 | "s3": "2019-12-02/label_counts.{}.json" 90 | }, 91 | "questions_final": { 92 | "local": "data/stackoverflow/Questions.Stratified.Final.{}.parquet", 93 | "s3": "s3://stackoverflow-events/2019-12-02/Questions.Stratified.Final.{}.parquet" 94 | }, 95 | "report": { 96 | "local": 
"data/stackoverflow/final_report.{}.json", 97 | "s3": "2019-12-02/final_report.{}.json" 98 | }, 99 | "bad_questions": { 100 | "local": "data/stackoverflow/Questions.Bad.{}.{}.parquet", 101 | "s3": "s3://stackoverflow-events/2019-12-02/Questions.Bad.{}.{}.parquet" 102 | } 103 | } -------------------------------------------------------------------------------- /requirements.dev.in: -------------------------------------------------------------------------------- 1 | nbstripout 2 | pip-tools 3 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | beautifulsoup4 2 | boto3 3 | cupy-cuda100 4 | dask 5 | dask-ml 6 | dill 7 | frozendict 8 | gensim 9 | gpustat 10 | # guesslang 11 | ipython 12 | iso8601 13 | jupyter 14 | lxml 15 | mlxtend 16 | munkres 17 | nltk 18 | numpy>=1.16.0 19 | pandas 20 | pip-tools 21 | pyarrow>=0.16.0 22 | pyspark 23 | requests 24 | s3fs 25 | seaborn 26 | # sentencepiece # Use Github and build/install https://github.com/google/sentencepiece 27 | scikit-learn 28 | shap 29 | snorkel 30 | spacy[cuda100]==2.2.2 31 | textblob 32 | textdistance 33 | texttable 34 | torch 35 | wandb 36 | -------------------------------------------------------------------------------- /settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.jediEnabled": true, 3 | "python.formatting.provider": "black", 4 | "python.linting.flake8Enabled": true, 5 | "python.linting.mypyEnabled": true, 6 | "python.linting.pydocstyleEnabled": true, 7 | "python.linting.pylintEnabled": false 8 | } 9 | --------------------------------------------------------------------------------