├── README.md ├── Global Vectors for Word Representation.ipynb └── Convolutional Neural Networks for Sentence Classification.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementations 2 | 3 | A collection of PyTorch notebooks implementing deep learning papers. These notebooks were developed while studying the papers and helped me reinforce the concepts. The goal is to maintain this repository and add more paper implementations. I hope that these notebooks help other practitioners as well. 4 | 5 | 6 | Name | Notebook 7 | --- | --- 8 | [A Neural Algorithm of Artistic Style](https://arxiv.org/pdf/1508.06576.pdf) | [![colab link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jaygala24/pytorch-implementations/blob/master/A%20Neural%20Algorithm%20of%20Artistic%20Style.ipynb) 9 | [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/pdf/1502.03044.pdf) | [![colab link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jaygala24/pytorch-implementations/blob/master/Show%2C%20Attend%20and%20Tell.ipynb) 10 | [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/pdf/1301.3781.pdf) | [![colab link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jaygala24/pytorch-implementations/blob/master/Efficient%20Estimation%20of%20Word%20Representations%20in%20Vector%20Space.ipynb) 11 | [GloVe: Global Vectors for Word Representation](https://nlp.stanford.edu/pubs/glove.pdf) | [![colab link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jaygala24/pytorch-implementations/blob/master/Global%20Vectors%20for%20Word%20Representation.ipynb) 12 | [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf) 
| [![colab link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jaygala24/pytorch-implementations/blob/master/Convolutional%20Neural%20Networks%20for%20Sentence%20Classification.ipynb) 13 | [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/pdf/1409.0473.pdf) | [![colab link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jaygala24/pytorch-implementations/blob/master/Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb) 14 | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) | [![colab link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jaygala24/pytorch-implementations/blob/master/Attention%20Is%20All%20You%20Need.ipynb) 15 | 16 | 17 | ### Contribution 18 | 19 | If you would like to contribute a deep learning paper implementation, then please send the pull request with the jupyter notebook filename as the paper name. Maintaining a single repository showcasing different paper implementations would help other practitioners. 
20 | -------------------------------------------------------------------------------- /Global Vectors for Word Representation.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Global Vectors for Word Representation.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.5"}},"cells":[{"cell_type":"markdown","metadata":{"id":"qyE1Ux-F_NPL"},"source":["# GloVe: Global Vectors for Word Representation\n","\n","_This notebook demonstrates the implementation of GloVe architecture proposed by [Pennington et al., 2014](https://nlp.stanford.edu/pubs/glove.pdf) for learning continuous word representations._\n"]},{"cell_type":"markdown","metadata":{"id":"OvXcHGt5_NPO"},"source":["**Note**: The notebook has been derived from my previously written blog post ([link](https://jaygala24.github.io/blog/python/pytorch/word-embeddings/word2vec/glove/2021/04/20/word_embeddings.html)).\n"]},{"cell_type":"markdown","metadata":{"id":"mzI-bvF0wZi-"},"source":["Previously, there were two main directions for learning distributed word representations: 1) count-based methods such as Latent Semantic Analysis (LSA) 2) direct prediction-based methods such as Word2Vec. Count-based methods make efficient use of statistical information about the corpus, but they do not capture the meaning of the words like word2vec and perform poorly on analogy tasks such as _**“king - queen = man - woman”**_. On the other hand, direct prediction-based methods capture the meaning of the word semantically and syntactically using local context but fail to consider the global count statistics. 
This is where GloVe comes into the picture and overcomes the drawbacks of both approaches by combining them. The authors proposed a global log-bilinear regression model to learn embeddings based on the co-occurrence of words. Note that GloVe does not use a neural network for learning word vectors.\n"]},{"cell_type":"markdown","metadata":{"id":"2gEXhip9ZBXz"},"source":["Here we will use the text8 corpus of cleaned Wikipedia articles provided by Matt Mahoney.\n"]},{"cell_type":"code","metadata":{"id":"4uvUbYanl_Ye"},"source":["!wget https://s3.amazonaws.com/video.udacity-data.com/topher/2018/October/5bbe6499_text8/text8.zip\n","!unzip text8.zip"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4A7yH9Gw_NPV"},"source":["### Imports\n"]},{"cell_type":"code","metadata":{"id":"Z9vQnTNHl_WC"},"source":["%matplotlib inline\n","%config InlineBackend.figure_format = \"retina\"\n","\n","import time\n","import random\n","from collections import Counter, defaultdict\n","\n","import numpy as np\n","import matplotlib.pyplot as plt\n","from sklearn.manifold import TSNE\n","\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import torch.nn.functional as F"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"PsPQvNL3l_UP"},"source":["# check if gpu is available since training is faster\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"y35N6tE7_NPY"},"source":["### Data Preprocessing and Loading\n"]},{"cell_type":"markdown","metadata":{"id":"dLqfygcLyCLq"},"source":["#### Co-occurrence matrix\n","\n","The authors used a co-occurrence matrix with a context window of fixed size $m$ to learn the word embeddings. 
Let's try to generate this matrix for the below toy example with a context window of size 2:\n","- I like deep learning\n","- I like NLP\n","- I enjoy flying\n"]},{"cell_type":"markdown","metadata":{"id":"8zULhp8T5vzN"},"source":["![co-occurrence matrix example]()\n","\n","(image source: https://stanford.io/3n4FH4H)\n"]},{"cell_type":"code","metadata":{"id":"30vSye4Yn4ji"},"source":["class GloVeDataset(object):\n"," def __init__(self, corpus, min_count=5, window_size=5):\n"," \"\"\" Prepares the training data for the glove model.\n"," Params:\n"," corpus (string): corpus of words\n"," min_count (int): words with minimum occurrence to consider\n"," window_size (int): context window size for generating co-occurrence matrix\n"," \"\"\"\n"," self.window_size = window_size\n"," self.min_count = min_count\n","\n"," tokens = corpus.split(\" \")\n"," word_counts = Counter(tokens)\n"," # only consider the words that occur more than 5 times in the corpus \n"," word_counts = Counter({word:count for word, count in word_counts.items() if count >= min_count})\n"," \n"," self.word2idx = {word: idx for idx, (word, _) in enumerate(word_counts.most_common())}\n"," self.idx2word = {idx: word for word, idx in self.word2idx.items()}\n","\n"," # create the training corpus\n"," self.token_ids = [self.word2idx[word] for word in tokens if word in self.word2idx]\n","\n"," # create the co-occurrence matrix for corpus\n"," self.create_cooccurrence_matrix()\n","\n","\n"," def create_cooccurrence_matrix(self):\n"," \"\"\" Creates the co-occurence matrix of center and context words based on the context window size.\n"," \"\"\"\n"," cooccurrence_counts = defaultdict(Counter)\n"," for current_idx, word in enumerate(self.token_ids):\n"," # find the start and end of context window\n"," left_boundary = max(current_idx - self.window_size, 0)\n"," right_boundary = min(current_idx + self.window_size + 1, len(self.token_ids))\n","\n"," # obtain the context words and center words based on context 
window\n"," context_positions = [pos for pos in range(left_boundary, right_boundary) if pos != current_idx]\n"," center_word_id = self.token_ids[current_idx]\n","\n"," for context_pos in context_positions:\n"," # add (1 / distance from center word) for this pair, where the distance is\n"," # measured in corpus positions (not as an offset inside the window)\n"," context_word_id = self.token_ids[context_pos]\n"," cooccurrence_counts[center_word_id][context_word_id] += 1 / abs(current_idx - context_pos)\n"," \n"," # create tensors for input word ids, output word ids and their co-occurrence count\n"," in_ids, out_ids, counts = [], [], []\n"," for center_word_id, counter in cooccurrence_counts.items():\n"," for context_word_id, count in counter.items():\n"," in_ids.append(center_word_id)\n"," out_ids.append(context_word_id)\n"," counts.append(count)\n","\n"," self.in_ids = torch.tensor(in_ids, dtype=torch.long)\n"," self.out_ids = torch.tensor(out_ids, dtype=torch.long)\n"," self.cooccurrence_counts = torch.tensor(counts, dtype=torch.float)\n","\n","\n"," def get_batches(self, batch_size):\n"," \"\"\" Creates the batches for training the network.\n"," Params:\n"," batch_size (int): size of the batch\n"," Yields:\n"," tuple of three torch tensors, each of shape (batch_size,): center word ids,\n"," context word ids and co-occurrence counts for a given batch\n"," \"\"\"\n"," random_ids = torch.tensor(np.random.choice(len(self.in_ids), len(self.in_ids), replace=False), dtype=torch.long)\n","\n"," for i in range(0, len(random_ids), batch_size):\n"," batch_ids = random_ids[i: i+batch_size]\n"," yield self.in_ids[batch_ids], self.out_ids[batch_ids], self.cooccurrence_counts[batch_ids]\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"FWvhJYihAm0t"},"source":["# read the file and initialize the GloVeDataset\n","with open(\"text8\", encoding=\"utf-8\") as f:\n"," corpus = f.read()\n","\n","dataset = GloVeDataset(corpus)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QsjjWFkw-fxO"},"source":["### GloVe Model\n","\n","Before we move 
ahead, let's get familiar with some notation.\n","- $X$ denotes the word-word co-occurrence matrix\n","- $X_{ij}$ denotes the number of times word $j$ occurs in the context of word $i$\n","- $X_i = \\sum_{k}{X_{ik}}$ denotes the number of times any word appears in the context of word $i$, where $k$ ranges over the distinct words that appear in the context of word $i$\n","- $P_{ij} = P(j | i) = \\frac{X_{ij}}{X_i}$ denotes the co-occurrence probability, i.e., the probability that word $j$ appears in the context of word $i$\n","\n","The denominator term in the co-occurrence probability accounts for global statistics, which word2vec does not use. The main idea behind GloVe is to encode meaning using the ratios of co-occurrence probabilities. Let's understand this by deriving the linear meaning components for the following words based on co-occurrence probability.\n"]},{"cell_type":"markdown","metadata":{"id":"jk6DHrzi6gYg"},"source":["![co-occurrence probabilities example]()\n","\n","(image source: http://nlp.stanford.edu/pubs/glove.pdf)\n"]},{"cell_type":"markdown","metadata":{"id":"PloC5f5YCK39"},"source":["The table shows the co-occurrence probabilities for words related to the thermodynamic phases of water (i.e., $ice$ and $steam$). The first two rows represent the co-occurrence probabilities for the words $ice$ and $steam$, whereas the last row represents their ratios. We can observe the following:\n","- the ratio is far from 1 (not neutral) for words related to only one of the two, such as $solid$ (related to $ice$) or $gas$ (related to $steam$)\n","- the ratio is close to 1 (neutral) for words that are either relevant to both $ice$ and $steam$ or irrelevant to both\n"," \n","The ratio of co-occurrence probabilities is a good starting point for learning word embeddings. 
Let's start with the most general function $F$ parametrized by 3 word vectors ($w_i$, $w_j$ and $\\tilde{w_k}$) given below:\n"," \n","$$\n","F(w_i, w_j, \\tilde{w_k}) = \\frac{P_{ik}}{P_{jk}}\n","$$\n"," \n","where $w, \\tilde{w} \\in \\mathrm{R^d}$; $w$ are the word vectors and $\\tilde{w}$ are the separate context word vectors.\n"," \n","How do we choose $F$?\n"," \n","There can be many possibilities for choosing $F$, but imposing some constraints allows us to restrict $F$ and select a unique choice. The goal is to learn word vectors (embeddings) that live in a common word vector space. Vector spaces are inherently linear, so the most natural way to compare two words is to take their vector difference, which makes our function $F$ as follows:\n"," \n","$$\n","F(w_i - w_j, \\tilde{w_k}) = \\frac{P_{ik}}{P_{jk}}\n","$$\n"," \n","We see that the right-hand side of the above equation is a scalar, while the arguments of $F$ are vectors. Choosing a complex function such as a neural network would introduce non-linearities, which works against our primary goal of capturing the linear meaning components of the word vector space. Instead, we take the dot product of the arguments so that $F$ acts on a scalar, just like the right-hand side.\n"," \n","$$\n","F((w_i - w_j)^T \\tilde{w_k}) = \\frac{P_{ik}}{P_{jk}}\n","$$\n"," \n","We also need to preserve symmetry, because the distinction between a word and a context word is arbitrary: if $ice$ can be used as a context word for $water$, then $water$ can also be used as a context word for $ice$. Simply put, we must be free to exchange the two roles, $w \\leftrightarrow \\tilde{w}$. The same exchange applies to the co-occurrence matrix, $X \\leftrightarrow X^T$. 
In order to restore the symmetry, we require that function $F$ is a homomorphism between groups $(\\mathrm{R, +})$ and $(\\mathrm{R_{>0}, \\times})$.\n"," \n","> _Given two groups, $\\small (G, ∗)$ and $\\small (H, \\cdot)$, a group homomorphism from $\\small (G, ∗)$ to $\\small (H, \\cdot)$ is a function $\\small h : G \\rightarrow H$ such that for all $u$ and $v$ in $\\small G$ it holds that $\\small h(u * v) = h(u) \\cdot h(v)$_\n"," \n","$$\n","\\begin{align}\n","F((w_i - w_j)^T \\tilde{w_k}) &= F(w_i^T \\tilde{w_k} + (-w_j^T \\tilde{w_k})) \\\\\n"," &= F(w_i^T \\tilde{w_k}) \\times F(-w_j^T \\tilde{w_k}) \\\\\n"," &= F(w_i^T \\tilde{w_k}) \\times F(w_j^T \\tilde{w_k})^{-1} \\\\\n"," &= \\frac{F(w_i^T \\tilde{w_k})}{F(w_j^T \\tilde{w_k})} \\\\\n","\\end{align}\n","$$\n"," \n","Comparing this with the ratio $P_{ik} / P_{jk}$ that $F$ must equal, we get the following:\n"," \n","$$\n","F(w_i^T \\tilde{w_k}) = P_{ik} = \\frac{X_{ik}}{X_i}\n","$$\n"," \n","A homomorphism from addition to multiplication is the exponential function, so we choose $F = \\exp$; its output is always positive, as required for a probability. Substituting $F = \\exp$ and taking the logarithm on both sides gives:\n"," \n","$$\n","w_i^T \\tilde{w_k} = \\log(P_{ik}) = \\log(X_{ik}) - \\log(X_i)\n","$$\n"," \n","On the right-hand side, the term $\\log(X_i)$ is independent of $k$ so it can be absorbed into a bias $b_i$ for $w_i$. 
Finally, we add bias $\\tilde{b_k}$ for $\\tilde{w_k}$ to restore the symmetry.\n"," \n","$$\n","w_i^T \\tilde{w_k} + b_i + \\tilde{b_k} = \\log(X_{ik})\n","$$\n"," \n","The above equation, written for a generic word pair $(i, j)$, leads to our objective function, a weighted least squares regression model where we use the weighting function $f(X_{ij})$ for word-word co-occurrences.\n"," \n","$$\n","J = \\sum_{i,j = 1}^{V}f(X_{ij}) (w_i^T \\tilde{w_j} + b_i + \\tilde{b_j} - \\log X_{ij})^2\n","$$\n"," \n","where $V$ is the size of the vocabulary.\n","\n","Here, the weighting function is defined as follows:\n","\n","$$\n","f(x) = \\begin{cases}\n"," (x / x_{max})^{\\alpha} & \\text{if}\\ x < x_{max} \\\\\n"," 1 & \\text{otherwise}\n"," \\end{cases}\n","$$\n","\n","where $x_{max}$ is the cutoff of the weighting function and $\\alpha$ is a power-scaling exponent similar to the one used in Word2Vec.\n"]},{"cell_type":"code","metadata":{"id":"erZmvpRB5B4r"},"source":["class GloVeModel(nn.Module):\n"," def __init__(self, vocab_size, embed_dim, x_max=100, alpha=0.75):\n"," \"\"\" GloVe model for learning word embeddings. 
Fits a weighted least squares \n"," regression to the log co-occurrence counts of (center, context) word pairs.\n"," Params:\n"," vocab_size (int): number of words in the vocabulary\n"," embed_dim (int): dimension of the embeddings to be generated\n"," x_max (int): cutoff of the weighting function\n"," alpha (float): exponent of the weighting function\n"," \"\"\"\n"," super(GloVeModel, self).__init__()\n"," self.vocab_size = vocab_size\n"," self.embed_dim = embed_dim\n"," self.x_max = x_max\n"," self.alpha = alpha\n","\n"," # embedding layers for input (center) and output (context) words along with biases\n"," self.embed_in = nn.Embedding(vocab_size, embed_dim)\n"," self.embed_out = nn.Embedding(vocab_size, embed_dim)\n"," self.bias_in = nn.Embedding(vocab_size, 1)\n"," self.bias_out = nn.Embedding(vocab_size, 1)\n","\n"," # initialize the embeddings with uniform dist and set bias to zero\n"," self.embed_in.weight.data.uniform_(-1, 1)\n"," self.embed_out.weight.data.uniform_(-1, 1)\n"," self.bias_in.weight.data.zero_()\n"," self.bias_out.weight.data.zero_()\n","\n"," \n"," def forward(self, in_ids, out_ids, cooccurrence_counts):\n"," \"\"\" Computes the weighted least squares loss for a batch of word pairs.\n"," The weights are updated by the optimizer in the training loop.\n"," Params:\n"," in_ids (torch tensor of shape (batch_size,)): indexes of the input words for a batch\n"," out_ids (torch tensor of shape (batch_size,)): indexes of the output words for a batch\n"," cooccurrence_counts (torch tensor of shape (batch_size,)): co-occurrence count of input \n"," and output words for a batch\n"," \"\"\"\n"," emb_in = self.embed_in(in_ids)\n"," emb_out = self.embed_out(out_ids)\n"," b_in = self.bias_in(in_ids)\n"," b_out = self.bias_out(out_ids)\n","\n"," # add 1 to counts i.e. 
cooccurrences in order to avoid log(0) case\n"," cooccurrence_counts += 1\n","\n"," # weighting function f(x): (x / x_max)^alpha below the cutoff, 1 at or above it\n"," weight_factor = torch.pow(cooccurrence_counts / self.x_max, self.alpha)\n"," weight_factor[cooccurrence_counts >= self.x_max] = 1\n"," \n"," # calculate the squared error between the model prediction and the log co-occurrence;\n"," # the biases are flattened to shape (batch_size,) so that they add element-wise\n"," # instead of broadcasting to a (batch_size, batch_size) matrix\n"," emb_prods = torch.sum(emb_in * emb_out, dim=1)\n"," log_cooccurrences = torch.log(cooccurrence_counts)\n"," distances = (emb_prods + b_in.view(-1) + b_out.view(-1) - log_cooccurrences) ** 2\n","\n"," return torch.mean(weight_factor * distances)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dLD163YZ_NP1"},"source":["### Training\n"]},{"cell_type":"code","metadata":{"id":"W0mKaLxLnz-H"},"source":["# initialize the model and optimizer\n","vocab_size = len(dataset.word2idx)\n","embed_dim = 300\n","model = GloVeModel(vocab_size, embed_dim).to(device)\n","optimizer = optim.Adagrad(model.parameters(), lr=0.05)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"LoLR9QlAnz8i"},"source":["# training the network\n","n_epochs = 5\n","batch_size = 512\n","\n","print(\"-\" * 60)\n","print(\"Start of training\")\n","print(\"-\" * 60)\n","\n","for epoch in range(n_epochs):\n"," losses = []\n"," start = time.time()\n","\n"," for input_word_ids, target_word_ids, cooccurrence_counts in dataset.get_batches(batch_size):\n"," # load tensor to GPU\n"," input_word_ids = input_word_ids.to(device)\n"," target_word_ids = target_word_ids.to(device)\n"," cooccurrence_counts = cooccurrence_counts.to(device)\n"," \n"," # forward pass\n"," loss = model(input_word_ids, target_word_ids, cooccurrence_counts)\n","\n"," # backward pass, optimize\n"," optimizer.zero_grad()\n"," loss.backward()\n"," optimizer.step()\n","\n"," losses.append(loss.item())\n"," \n"," end = time.time()\n","\n"," print(f\"Epoch: {epoch + 1}/{n_epochs}\\tAvg training loss: {np.mean(losses):.6f}\\tElapsed time: {(end - start):.0f} s\")\n","\n","print(\"-\" * 
60)\n","print(\"End of training\")\n","print(\"-\" * 60)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kTxZa53l_NP4"},"source":["### Inference\n"]},{"cell_type":"code","metadata":{"id":"NF6CZu_Ynz6e"},"source":["# get the trained embeddings from the model\n","emb_in = model.embed_in.weight.to(\"cpu\").data.numpy()\n","emb_out = model.embed_out.weight.to(\"cpu\").data.numpy()\n","embeddings = emb_in + emb_out\n","\n","# number of words to be visualized\n","viz_words = 200\n","\n","# projecting the embedding dimension from 300 to 2\n","tsne = TSNE()\n","embed_tsne = tsne.fit_transform(embeddings[:viz_words, :])\n","\n","# plot the projected embeddings\n","plt.figure(figsize=(16, 16))\n","for idx in range(viz_words):\n"," plt.scatter(*embed_tsne[idx, :], color=\"blue\")\n"," plt.annotate(dataset.idx2word[idx], (embed_tsne[idx, 0], embed_tsne[idx, 1]), alpha=0.7)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"oEXgG-hDIMdT"},"source":["## References\n","\n","1. [GloVe: Global Vectors for Word Representation](https://nlp.stanford.edu/pubs/glove.pdf)\n","2. [Group homomorphism](https://en.wikipedia.org/wiki/Group_homomorphism)\n","3. [Homomorphism in GloVe](https://datascience.stackexchange.com/questions/27042/glove-vector-representation-homomorphism-question)\n","4. [A GloVe Implementation in Python](http://www.foldl.me/2014/glove-python/)\n","5. 
[Pytorch Global Vectors for Word Representation](https://github.com/kefirski/pytorch_GloVe)\n"]}]} -------------------------------------------------------------------------------- /Convolutional Neural Networks for Sentence Classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "yeMtUsLNDOsR" 7 | }, 8 | "source": [ 9 | "# Convolutional Neural Networks for Sentence Classification\n", 10 | "\n", 11 | "_This notebook demonstrates the implementation of Convolutional Neural Networks for Sentence Classification proposed by [Kim, 2014](https://arxiv.org/pdf/1408.5882.pdf)._\n" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "Convolutional Neural Networks (CNNs) were originally developed for computer vision and have achieved state-of-the-art performance in computer vision and speech recognition. CNNs act as feature extractors: they scan different regions of the image using a kernel, and the output of each layer is passed to the next CNN layer. The lower layers in CNNs are useful for detecting low-level features such as edges, whereas the higher layers are useful for detecting high-level features such as facial parts (eyes, nose, ears, etc.).\n", 19 | "\n", 20 | "The paper by [Kim, 2014](https://arxiv.org/pdf/1408.5882.pdf) was one of the earliest works demonstrating the application of CNNs to NLP tasks, specifically text classification. In this case, the CNN is used as a feature extractor that encodes semantic features of the text, and these features are then fed into a classifier.\n", 21 | "\n", 22 | "RNNs extract features from the prefix of a sequence, which is not always helpful: preserving the complete context is very difficult for longer sequences, and this can lead to wrong predictions. 
CNNs help us to extract subsequence features which are beneficial for classification tasks as the dominant features context would be preserved.\n", 23 | "\n", 24 | "We will build a CNN based sentiment classifier for IMDB movie reviews.\n" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "metadata": { 31 | "executionInfo": { 32 | "elapsed": 5029, 33 | "status": "ok", 34 | "timestamp": 1622648644726, 35 | "user": { 36 | "displayName": "Jay Gala", 37 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 38 | "userId": "11525158385229787114" 39 | }, 40 | "user_tz": -330 41 | }, 42 | "id": "QM7kOJJLFB28" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "%%capture\n", 47 | "# download the spacy language models for english\n", 48 | "!python -m spacy download en --quiet" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": { 55 | "executionInfo": { 56 | "elapsed": 2867, 57 | "status": "ok", 58 | "timestamp": 1622648647583, 59 | "user": { 60 | "displayName": "Jay Gala", 61 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 62 | "userId": "11525158385229787114" 63 | }, 64 | "user_tz": -330 65 | }, 66 | "id": "ssZUPl3UFNC6" 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "%%capture\n", 71 | "!pip install torchtext --upgrade" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": { 78 | "executionInfo": { 79 | "elapsed": 154436, 80 | "status": "ok", 81 | "timestamp": 1622648802013, 82 | "user": { 83 | "displayName": "Jay Gala", 84 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 85 | "userId": "11525158385229787114" 86 | }, 87 | "user_tz": -330 88 | }, 89 | "id": "-0IQi07nz9J6" 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "%%capture\n", 94 | "!wget 
https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip\n", 95 | "!unzip crawl-300d-2M.vec.zip\n", 96 | "!rm -rf crawl-300d-2M.vec.zip" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": { 102 | "id": "B--2aTzsFPgA" 103 | }, 104 | "source": [ 105 | "## Imports\n" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 4, 111 | "metadata": { 112 | "executionInfo": { 113 | "elapsed": 5741, 114 | "status": "ok", 115 | "timestamp": 1622648807738, 116 | "user": { 117 | "displayName": "Jay Gala", 118 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 119 | "userId": "11525158385229787114" 120 | }, 121 | "user_tz": -330 122 | }, 123 | "id": "ZyKMGPOjFUu5" 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "import random\n", 128 | "import numpy as np\n", 129 | "\n", 130 | "import torch\n", 131 | "import torch.nn as nn\n", 132 | "import torch.optim as optim\n", 133 | "import torch.nn.functional as F\n", 134 | "\n", 135 | "import spacy\n", 136 | "import torchtext\n", 137 | "import torchtext.vocab as vocab\n", 138 | "from torchtext.legacy.datasets import IMDB\n", 139 | "from torchtext.legacy.data import Field, LabelField, BucketIterator\n", 140 | "\n", 141 | "import warnings\n", 142 | "from tqdm.notebook import tqdm\n", 143 | "warnings.filterwarnings('ignore')" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 5, 149 | "metadata": { 150 | "executionInfo": { 151 | "elapsed": 23, 152 | "status": "ok", 153 | "timestamp": 1622648807739, 154 | "user": { 155 | "displayName": "Jay Gala", 156 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 157 | "userId": "11525158385229787114" 158 | }, 159 | "user_tz": -330 160 | }, 161 | "id": "DjC3C6_2FYB6" 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "# for reproducibility\n", 166 | "# refer 
https://pytorch.org/docs/stable/notes/randomness.html\n", 167 | "SEED = 42\n", 168 | "\n", 169 | "random.seed(SEED)\n", 170 | "np.random.seed(SEED)\n", 171 | "torch.manual_seed(SEED)\n", 172 | "torch.cuda.manual_seed(SEED)\n", 173 | "torch.backends.cudnn.deterministic = True" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 6, 179 | "metadata": { 180 | "executionInfo": { 181 | "elapsed": 16, 182 | "status": "ok", 183 | "timestamp": 1622648807740, 184 | "user": { 185 | "displayName": "Jay Gala", 186 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 187 | "userId": "11525158385229787114" 188 | }, 189 | "user_tz": -330 190 | }, 191 | "id": "KV2NLb9oFaav" 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": { 201 | "id": "WjxWOXY2F3Fg" 202 | }, 203 | "source": [ 204 | "## Data Preprocessing and Loading\n" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 32, 210 | "metadata": { 211 | "executionInfo": { 212 | "elapsed": 810146, 213 | "status": "ok", 214 | "timestamp": 1622650907554, 215 | "user": { 216 | "displayName": "Jay Gala", 217 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 218 | "userId": "11525158385229787114" 219 | }, 220 | "user_tz": -330 221 | }, 222 | "id": "x9DZy-LFH7Bc" 223 | }, 224 | "outputs": [], 225 | "source": [ 226 | "# create field objects for text and label\n", 227 | "review = Field(tokenize='spacy', batch_first=True)\n", 228 | "sentiment = LabelField(batch_first=True)\n", 229 | "\n", 230 | "# load the imdb dataset\n", 231 | "train_data, test_data = IMDB.splits(text_field=review, label_field=sentiment)\n", 232 | "train_data, valid_data = train_data.split()" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 33, 
238 | "metadata": { 239 | "executionInfo": { 240 | "elapsed": 71638, 241 | "status": "ok", 242 | "timestamp": 1622650979181, 243 | "user": { 244 | "displayName": "Jay Gala", 245 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 246 | "userId": "11525158385229787114" 247 | }, 248 | "user_tz": -330 249 | }, 250 | "id": "9eFs-n60Jdrc" 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "# load the pretrained fasttext embeddings and build the vocabulary\n", 255 | "en_fast_embed = vocab.Vectors(name='crawl-300d-2M.vec', cache='.', unk_init = torch.Tensor.normal_)\n", 256 | "review.build_vocab(train_data, max_size=25000, vectors=en_fast_embed)\n", 257 | "sentiment.build_vocab(train_data)\n", 258 | "\n", 259 | "del en_fast_embed" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 34, 265 | "metadata": { 266 | "colab": { 267 | "base_uri": "https://localhost:8080/" 268 | }, 269 | "executionInfo": { 270 | "elapsed": 39, 271 | "status": "ok", 272 | "timestamp": 1622650979185, 273 | "user": { 274 | "displayName": "Jay Gala", 275 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 276 | "userId": "11525158385229787114" 277 | }, 278 | "user_tz": -330 279 | }, 280 | "id": "LbQNCow6HuMp", 281 | "outputId": "0254c26d-ab4e-4b55-a9ad-56377b8f0e5b" 282 | }, 283 | "outputs": [ 284 | { 285 | "name": "stdout", 286 | "output_type": "stream", 287 | "text": [ 288 | "# of training examples: 17500\n", 289 | "# of validation examples: 7500\n", 290 | "# of testing examples: 25000\n", 291 | "dict_keys(['text', 'label'])\n", 292 | "dict_values([['The', 'fact', 'that', 'someone', 'actually', 'spent', 'money', 'on', 'such', 'a', 'bad', 'script', ',', 'is', 'beyond', 'me', '.', 'This', 'really', 'must', 'be', 'one', 'of', 'the', 'worst', 'films', ',', 'in', 'addition', 'to', '\"', 'Haunted', 'Highway', '\"', 'I', 'have', 'ever', 'seen', '.', 'BAD', 'actors', ',', 
'and', 'a', 'really', 'bad', 'story', '.', 'There', \"'s\", 'no', 'normal', 'reactions', 'to', 'any', 'event', 'in', 'this', 'film', ',', 'and', 'even', 'though', 'it', \"'s\", 'Halloween', ',', 'normal', 'people', 'would', 'have', 'bigger', 'reactions', 'when', 'they', \"'re\", 'witnessing', 'their', 'father', 'being', 'killed', ',', 'not', 'to', 'mention', 'gutted', ',', 'people', 'with', 'tape', 'covering', 'their', 'airways', ',', 'not', 'being', 'able', 'to', 'breathe', '(', 'in', 'a', 'room', 'with', 'at', 'least', '50', 'people', 'I', 'might', 'add', ')', 'and', 'some', 'person', 'dressed', 'up', 'as', 'Satan', 'dragging', 'dead', 'people', 'out', 'of', 'his', 'house', ',', 'even', 'an', '8', 'year', 'old', 'would', 'see', 'the', 'difference', 'between', 'a', 'doll', 'and', 'a', 'person', '.', 'Not', 'to', 'mention', 'the', 'fact', 'that', 'no', 'one', 'could', 'possibly', 'be', 'that', 'naive', 'and', 'dumb', 'to', 'believe', 'the', 'reality', 'of', 'Satan', 'and', 'Jesus', \"'\", 'appearances', 'on', 'the', 'same', 'day', ',', 'like', 'this', 'kid', 'does', '.', 'When', 'i', 'was', '8', ',', 'I', 'sure', 'had', 'more', 'brains', 'than', 'that', '.', '<', 'br', '/>But', ',', 'the', 'really', 'stupid', 'thing', 'is', 'that', 'everyone', 'else', 'seems', 'to', 'be', 'falling', 'for', 'this', 'mute', 'Satan', 'look', '-', 'alike', 'as', 'well', ',', 'no', 'questions', 'asked', '.', 'The', 'question', 'throughout', 'the', 'film', 'is', ',', 'is', 'it', 'really', 'Satan', ',', 'or', 'is', 'it', 'some', 'crazy', 'person', 'killing', 'people', 'off', 'whenever', 'he', 'feels', 'like', 'it', '?', 'Well', ',', 'he', \"'s\", 'got', 'human', 'hands', ',', 'arms', ',', 'built', 'and', 'whatever', ',', 'so', 'I', 'guess', 'he', \"'s\", 'supposed', 'to', 'be', 'in', 'the', 'movie', 'as', 'well', ',', 'otherwise', 'they', 'did', 'a', 'lousy', 'job', 'concealing', 'it', '.', 'Then', ',', 'with', 'this', 'person', 'being', 'human', 'and', 'all', ',', 'he', 'was', 'able', 
'to', 'kill', 'an', 'old', 'lady', ',', 'a', 'man', 'and', 'his', 'mistress', ',', '5', '(', '!', '!', '?', '?', '?', ')', 'cops', '(', 'all', 'with', 'guns', 'and', 'training', 'i', 'presume', ')', ',', 'and', 'a', 'few', 'other', 'people', '.....', 'and', 'obviously', 'everyone', 'was', 'just', 'standing', 'there', 'waiting', 'for', 'him', ',', 'or', 'what?The', 'whole', 'concept', 'and', 'way', 'of', 'telling', 'the', 'story', 'is', 'absolutely', 'the', 'worst', 'thing', 'I', \"'ve\", 'seen', ',', 'and', 'I', 'would', 'never', 'recommend', 'anyone', 'to', 'waste', '1', 'hour', 'and', '30', 'minutes', 'of', 'their', 'lives', 'to', 'watch', 'this', 'total', 'crap', '.'], 'neg'])\n" 293 | ] 294 | } 295 | ], 296 | "source": [ 297 | "print(f'# of training examples: {len(train_data.examples)}')\n", 298 | "print(f'# of validation examples: {len(valid_data.examples)}')\n", 299 | "print(f'# of testing examples: {len(test_data.examples)}')\n", 300 | "\n", 301 | "print(train_data[0].__dict__.keys())\n", 302 | "print(train_data[0].__dict__.values())" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 35, 308 | "metadata": { 309 | "executionInfo": { 310 | "elapsed": 23, 311 | "status": "ok", 312 | "timestamp": 1622650979186, 313 | "user": { 314 | "displayName": "Jay Gala", 315 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 316 | "userId": "11525158385229787114" 317 | }, 318 | "user_tz": -330 319 | }, 320 | "id": "j_7k6_xaKd6e" 321 | }, 322 | "outputs": [], 323 | "source": [ 324 | "BATCH_SIZE = 64\n", 325 | "\n", 326 | "train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n", 327 | "    (train_data, valid_data, test_data),\n", 328 | "    batch_size=BATCH_SIZE,\n", 329 | "    device=device\n", 330 | ")" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 36, 336 | "metadata": { 337 | "colab": { 338 | "base_uri": "https://localhost:8080/" 339 | }, 340 | 
"executionInfo": { 341 | "elapsed": 23, 342 | "status": "ok", 343 | "timestamp": 1622650979187, 344 | "user": { 345 | "displayName": "Jay Gala", 346 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 347 | "userId": "11525158385229787114" 348 | }, 349 | "user_tz": -330 350 | }, 351 | "id": "QE2tx68lLCNf", 352 | "outputId": "987ffbb6-8da9-42c2-94f1-e3da1b74fc64" 353 | }, 354 | "outputs": [ 355 | { 356 | "name": "stdout", 357 | "output_type": "stream", 358 | "text": [ 359 | "{'text': torch.Size([64, 1150]), 'label': torch.Size([64])}\n" 360 | ] 361 | } 362 | ], 363 | "source": [ 364 | "# sanity check to see if the data loader is working\n", 365 | "x = next(iter(train_iterator))\n", 366 | "\n", 367 | "print({'text': x.text.shape, 'label': x.label.shape})" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "metadata": { 373 | "id": "ZtW9HLyyLPOO" 374 | }, 375 | "source": [ 376 | "## Model Architecture\n" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "![cnn_architecture]( \"CNN architecture for sentence classification\")\n", 384 | "\n", 385 | "(source: https://arxiv.org/pdf/1510.03820.pdf)\n" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "metadata": {}, 391 | "source": [ 392 | "A CNN-based text classifier reads the input sequence and generates feature representations of its subsequences, which are then used for classification.\n", 393 | "\n", 394 | "For the implementation in this notebook, we first apply the embedding layer to the input sequence $x$ to obtain the embedded representation $x_{emb}$. Next, we apply 1D convolutions with different filter sizes to obtain a feature map $x_{conv_i}$ for each filter $i$. We then apply a ReLU activation and max pooling over time to reduce each feature map to a single scalar $x_{pool_i}$, and concatenate these scalars into $x_{concat}$. 
Finally, dropout is applied to the concatenated representation of all $n$ filters, followed by a linear transformation with weight matrix $W$ to compute the logits. Either sigmoid or softmax is applied to these logits, depending on whether the classification problem is binary or multi-class.\n", 395 | "\n", 396 | "$$\n", 397 | "x_{emb} = f_{embedding}(x) \\\\\n", 398 | "x_{conv_i} = \\text{conv1D}_i(x_{emb}) \\\\\n", 399 | "x_{pool_i} = \\text{maxpool}(\\text{relu}(x_{conv_i})) \\\\\n", 400 | "x_{concat} = [x_{pool_1}; \\dots; x_{pool_n}] \\\\\n", 401 | "\\text{logits} = W \\, \\text{dropout}(x_{concat})\n", 402 | "$$\n" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 37, 408 | "metadata": { 409 | "executionInfo": { 410 | "elapsed": 18, 411 | "status": "ok", 412 | "timestamp": 1622650979188, 413 | "user": { 414 | "displayName": "Jay Gala", 415 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 416 | "userId": "11525158385229787114" 417 | }, 418 | "user_tz": -330 419 | }, 420 | "id": "k8D3rfn_LcH8" 421 | }, 422 | "outputs": [], 423 | "source": [ 424 | "class CNNTextClassifier(nn.Module):\n", 425 | "    \"\"\" Simple 1D Convolutional Neural Network for Sentence Classification.\n", 426 | "    \"\"\"\n", 427 | "\n", 428 | "    def __init__(self, vocab_size, n_classes, embed_size, filter_sizes, n_filters, pad_idx, dropout_rate=0.5):\n", 429 | "        super(CNNTextClassifier, self).__init__()\n", 430 | "\n", 431 | "        self.vocab_size = vocab_size\n", 432 | "        self.embed_size = embed_size\n", 433 | "        self.filter_sizes = filter_sizes\n", 434 | "        self.n_filters = n_filters\n", 435 | "        self.n_classes = n_classes\n", 436 | "        self.pad_idx = pad_idx\n", 437 | "        self.dropout_rate = dropout_rate\n", 438 | "\n", 439 | "        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=pad_idx)\n", 440 | "        self.conv_list = nn.ModuleList([\n", 441 | "            nn.Conv1d(in_channels=embed_size, out_channels=n_filters, kernel_size=filter_sizes[i])\n", 442 | "            for i in 
range(len(filter_sizes))\n", 443 | "        ])\n", 444 | "        self.fc = nn.Linear(len(filter_sizes) * n_filters, n_classes)\n", 445 | "        self.dropout = nn.Dropout(dropout_rate)\n", 446 | "\n", 447 | "\n", 448 | "    def load_pretrained_embeddings(self, embeddings, fine_tune=False):\n", 449 | "        self.embedding.weight = nn.Parameter(embeddings)\n", 450 | "        for p in self.embedding.parameters():\n", 451 | "            p.requires_grad = fine_tune\n", 452 | "\n", 453 | "\n", 454 | "    def _conv_block(self, x, conv_layer):\n", 455 | "        \"\"\" x: (batch_size, embed_size, seq_len)\n", 456 | "        \"\"\"\n", 457 | "        x_conv = F.relu(conv_layer(x))  # (batch_size, n_filters, seq_len - filter_sizes[i] + 1)\n", 458 | "        x_pool = F.max_pool1d(x_conv, kernel_size=x_conv.shape[2]).squeeze(2)  # (batch_size, n_filters)\n", 459 | "        return x_pool\n", 460 | "\n", 461 | "\n", 462 | "    def forward(self, x):\n", 463 | "        \"\"\" x: (batch_size, seq_len)\n", 464 | "        \"\"\"\n", 465 | "        embed = self.embedding(x)  # (batch_size, seq_len, embed_size)\n", 466 | "        embed = embed.permute(0, 2, 1)  # (batch_size, embed_size, seq_len)\n", 467 | "\n", 468 | "        conv_stack = [\n", 469 | "            self._conv_block(embed, conv) for conv in self.conv_list\n", 470 | "        ]\n", 471 | "\n", 472 | "        pooled_values = torch.cat(conv_stack, dim=1)  # (batch_size, len(filter_sizes) * n_filters)\n", 473 | "\n", 474 | "        logits = self.fc(self.dropout(pooled_values))  # (batch_size, n_classes)\n", 475 | "\n", 476 | "        return logits" 477 | ] 478 | }, 479 | { 480 | "cell_type": "markdown", 481 | "metadata": { 482 | "id": "iVgiu3sAX8Kt" 483 | }, 484 | "source": [ 485 | "## Helper Utilities\n" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 38, 491 | "metadata": { 492 | "executionInfo": { 493 | "elapsed": 17, 494 | "status": "ok", 495 | "timestamp": 1622650979188, 496 | "user": { 497 | "displayName": "Jay Gala", 498 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 499 | "userId": 
"11525158385229787114" 500 | }, 501 | "user_tz": -330 502 | }, 503 | "id": "T_xZRYpYYBf6" 504 | }, 505 | "outputs": [], 506 | "source": [ 507 | "def model_summary(model):\n", 508 | "    print(model)\n", 509 | "    print(f'# of trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')\n", 510 | "    print(f'# of non-trainable params: {sum(p.numel() for p in model.parameters() if not p.requires_grad):,}')" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": 39, 516 | "metadata": { 517 | "executionInfo": { 518 | "elapsed": 17, 519 | "status": "ok", 520 | "timestamp": 1622650979189, 521 | "user": { 522 | "displayName": "Jay Gala", 523 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 524 | "userId": "11525158385229787114" 525 | }, 526 | "user_tz": -330 527 | }, 528 | "id": "tpiECqiLdF4x" 529 | }, 530 | "outputs": [], 531 | "source": [ 532 | "def accuracy(preds, y):\n", 533 | "    top_preds = preds.argmax(dim=1, keepdim=True)\n", 534 | "    correct = (top_preds == y.view_as(top_preds)).sum()\n", 535 | "    acc = correct / len(y)\n", 536 | "    return acc" 537 | ] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "metadata": { 542 | "id": "mudqBztPX3S5" 543 | }, 544 | "source": [ 545 | "## Training\n" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 46, 551 | "metadata": { 552 | "executionInfo": { 553 | "elapsed": 3, 554 | "status": "ok", 555 | "timestamp": 1622650986652, 556 | "user": { 557 | "displayName": "Jay Gala", 558 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 559 | "userId": "11525158385229787114" 560 | }, 561 | "user_tz": -330 562 | }, 563 | "id": "kL7MVl7xxpAq" 564 | }, 565 | "outputs": [], 566 | "source": [ 567 | "def train_fn(model, iterator, optimizer, criterion):\n", 568 | "    model.train()\n", 569 | "    epoch_loss = 0\n", 570 | "    epoch_acc = 0\n", 571 | "\n", 572 | "    tk0 = 
tqdm(iterator, total=len(iterator), position=0, leave=True)\n", 573 | "\n", 574 | "    for idx, batch in enumerate(tk0):\n", 575 | "\n", 576 | "        # forward pass\n", 577 | "        optimizer.zero_grad()\n", 578 | "        logits = model(batch.text)  # (batch_size, n_classes)\n", 579 | "\n", 580 | "        # calculate loss\n", 581 | "        loss = criterion(logits, batch.label)\n", 582 | "\n", 583 | "        # calculate accuracy\n", 584 | "        acc = accuracy(logits, batch.label)\n", 585 | "\n", 586 | "        # backward pass\n", 587 | "        loss.backward()\n", 588 | "\n", 589 | "        # update model parameters\n", 590 | "        optimizer.step()\n", 591 | "\n", 592 | "        epoch_loss += loss.item()\n", 593 | "        epoch_acc += acc.item()\n", 594 | "\n", 595 | "        tk0.update(0)\n", 596 | "\n", 597 | "    tk0.close()\n", 598 | "\n", 599 | "    return epoch_acc / len(iterator), epoch_loss / len(iterator)" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": 47, 605 | "metadata": { 606 | "executionInfo": { 607 | "elapsed": 4, 608 | "status": "ok", 609 | "timestamp": 1622650987090, 610 | "user": { 611 | "displayName": "Jay Gala", 612 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 613 | "userId": "11525158385229787114" 614 | }, 615 | "user_tz": -330 616 | }, 617 | "id": "hmytuJE8yFMn" 618 | }, 619 | "outputs": [], 620 | "source": [ 621 | "def eval_fn(model, iterator, criterion):\n", 622 | "    model.eval()\n", 623 | "    epoch_loss = 0\n", 624 | "    epoch_acc = 0\n", 625 | "\n", 626 | "    tk0 = tqdm(iterator, total=len(iterator), position=0, leave=True)\n", 627 | "\n", 628 | "    with torch.no_grad():\n", 629 | "        for idx, batch in enumerate(tk0):\n", 630 | "\n", 631 | "            # forward pass\n", 632 | "            logits = model(batch.text)  # (batch_size, n_classes)\n", 633 | "\n", 634 | "            # calculate loss\n", 635 | "            loss = criterion(logits, batch.label)\n", 636 | "\n", 637 | "            # calculate accuracy\n", 638 | "            acc = accuracy(logits, batch.label)\n", 639 | "\n", 640 | "            epoch_loss += loss.item()\n", 641 | "            
epoch_acc += acc.item()\n", 642 | "\n", 643 | "            tk0.update(0)\n", 644 | "\n", 645 | "    tk0.close()\n", 646 | "\n", 647 | "    return epoch_acc / len(iterator), epoch_loss / len(iterator)" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": 48, 653 | "metadata": { 654 | "executionInfo": { 655 | "elapsed": 15, 656 | "status": "ok", 657 | "timestamp": 1622650987684, 658 | "user": { 659 | "displayName": "Jay Gala", 660 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 661 | "userId": "11525158385229787114" 662 | }, 663 | "user_tz": -330 664 | }, 665 | "id": "A2OGqhaIYXB3" 666 | }, 667 | "outputs": [], 668 | "source": [ 669 | "# hyperparameters\n", 670 | "VOCAB_SIZE = len(review.vocab)\n", 671 | "EMBED_SIZE = review.vocab.vectors.shape[1]\n", 672 | "FILTER_SIZES = [3, 4, 5]\n", 673 | "N_FILTERS = 100\n", 674 | "N_CLASSES = len(sentiment.vocab)\n", 675 | "DROPOUT_RATE = 0.5\n", 676 | "PAD_IDX = review.vocab.stoi[review.pad_token]\n", 677 | "N_EPOCHS = 20" 678 | ] 679 | }, 680 | { 681 | "cell_type": "code", 682 | "execution_count": 54, 683 | "metadata": { 684 | "colab": { 685 | "base_uri": "https://localhost:8080/" 686 | }, 687 | "executionInfo": { 688 | "elapsed": 686, 689 | "status": "ok", 690 | "timestamp": 1622651034647, 691 | "user": { 692 | "displayName": "Jay Gala", 693 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 694 | "userId": "11525158385229787114" 695 | }, 696 | "user_tz": -330 697 | }, 698 | "id": "1rmb8RcUYZl9", 699 | "outputId": "2727b9e4-f361-4ee8-e3f9-493e281c274b" 700 | }, 701 | "outputs": [ 702 | { 703 | "name": "stdout", 704 | "output_type": "stream", 705 | "text": [ 706 | "CNNTextClassifier(\n", 707 | "  (embedding): Embedding(25002, 300, padding_idx=1)\n", 708 | "  (conv_list): ModuleList(\n", 709 | "    (0): Conv1d(300, 100, kernel_size=(3,), stride=(1,))\n", 710 | "    (1): Conv1d(300, 100, kernel_size=(4,), 
stride=(1,))\n", 711 | "    (2): Conv1d(300, 100, kernel_size=(5,), stride=(1,))\n", 712 | "  )\n", 713 | "  (fc): Linear(in_features=300, out_features=2, bias=True)\n", 714 | "  (dropout): Dropout(p=0.5, inplace=False)\n", 715 | ")\n", 716 | "# of trainable params: 7,861,502\n", 717 | "# of non-trainable params: 0\n" 718 | ] 719 | } 720 | ], 721 | "source": [ 722 | "model = CNNTextClassifier(VOCAB_SIZE, N_CLASSES, EMBED_SIZE, FILTER_SIZES, N_FILTERS, PAD_IDX, DROPOUT_RATE)\n", 723 | "# load the pre-trained fastText embeddings\n", 724 | "model.load_pretrained_embeddings(review.vocab.vectors, fine_tune=True)\n", 725 | "model = model.to(device)\n", 726 | "\n", 727 | "model_summary(model)" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": 55, 733 | "metadata": { 734 | "executionInfo": { 735 | "elapsed": 20, 736 | "status": "ok", 737 | "timestamp": 1622651034651, 738 | "user": { 739 | "displayName": "Jay Gala", 740 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 741 | "userId": "11525158385229787114" 742 | }, 743 | "user_tz": -330 744 | }, 745 | "id": "vW7sfVU1aPVi" 746 | }, 747 | "outputs": [], 748 | "source": [ 749 | "optimizer = optim.Adadelta(model.parameters(), lr=1e-1)\n", 750 | "scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n", 751 | "    optimizer, patience=0, threshold=0.001, mode='max'\n", 752 | ")" 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "execution_count": 56, 758 | "metadata": { 759 | "executionInfo": { 760 | "elapsed": 21, 761 | "status": "ok", 762 | "timestamp": 1622651034655, 763 | "user": { 764 | "displayName": "Jay Gala", 765 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 766 | "userId": "11525158385229787114" 767 | }, 768 | "user_tz": -330 769 | }, 770 | "id": "pcaK7UVSkiv6" 771 | }, 772 | "outputs": [], 773 | "source": [ 774 | "criterion = nn.CrossEntropyLoss()" 775 | ] 776 | }, 777 | { 778 | 
"cell_type": "code", 779 | "execution_count": null, 780 | "metadata": { 781 | "id": "p7VOszRie17P" 782 | }, 783 | "outputs": [], 784 | "source": [ 785 | "best_acc = 0\n", 786 | "es_patience = 3\n", 787 | "patience = 0\n", 788 | "model_path = 'model.pth'\n", 789 | "\n", 790 | "\n", 791 | "for epoch in range(1, N_EPOCHS + 1):\n", 792 | "    # one epoch training\n", 793 | "    train_acc, train_loss = train_fn(model, train_iterator, optimizer, criterion)\n", 794 | "\n", 795 | "    # one epoch validation\n", 796 | "    valid_acc, valid_loss = eval_fn(model, valid_iterator, criterion)\n", 797 | "\n", 798 | "    print(f'Epoch: {epoch}, Train Accuracy: {train_acc * 100:.2f}%, Train Loss: {train_loss:.4f}, Valid Accuracy: {valid_acc * 100:.2f}%, Valid Loss: {valid_loss:.4f}')\n", 799 | "\n", 800 | "    scheduler.step(valid_acc)\n", 801 | "\n", 802 | "    is_best = valid_acc > best_acc\n", 803 | "    if is_best:\n", 804 | "        print(f'Best accuracy improved ({best_acc * 100:.2f}% -> {valid_acc * 100:.2f}%). Saving Model!')\n", 805 | "        best_acc = valid_acc\n", 806 | "        patience = 0\n", 807 | "        torch.save(model.state_dict(), model_path)\n", 808 | "    else:\n", 809 | "        patience += 1\n", 810 | "        print(f'Early stopping counter: {patience} out of {es_patience}')\n", 811 | "        if patience == es_patience:\n", 812 | "            print(f'Early stopping! 
Best accuracy: {best_acc * 100:.2f}%')\n", 813 | "            break" 814 | ] 815 | }, 816 | { 817 | "cell_type": "code", 818 | "execution_count": 58, 819 | "metadata": { 820 | "colab": { 821 | "base_uri": "https://localhost:8080/", 822 | "height": 101, 823 | "referenced_widgets": [ 824 | "0e8d20a7c37f41ec8be9bedf0b24c5ad", 825 | "77cffdfd79db46d098ff3e85694af708", 826 | "5092dec0198d4ddca6ec280a8efe20fa", 827 | "97565fb53ca747baac3a3a795dbb60df", 828 | "76a92c955bba41f4862918d88c888376", 829 | "57835188d57b40b3b7ab7b32167d5f20", 830 | "7ddcff75499442dc8e472c6121eed9c2", 831 | "9ed81c02af6047dea211c5483ba358bc" 832 | ] 833 | }, 834 | "executionInfo": { 835 | "elapsed": 4116, 836 | "status": "ok", 837 | "timestamp": 1622651418336, 838 | "user": { 839 | "displayName": "Jay Gala", 840 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 841 | "userId": "11525158385229787114" 842 | }, 843 | "user_tz": -330 844 | }, 845 | "id": "xkLMlMCeuRJv", 846 | "outputId": "8140b6d9-96aa-491d-dec5-bfa51ba4a05a" 847 | }, 848 | "outputs": [ 849 | { 850 | "name": "stdout", 851 | "output_type": "stream", 852 | "text": [ 853 | "Evaluating the model on test data ...\n" 854 | ] 855 | }, 856 | { 857 | "data": { 858 | "application/vnd.jupyter.widget-view+json": { 859 | "model_id": "0e8d20a7c37f41ec8be9bedf0b24c5ad", 860 | "version_major": 2, 861 | "version_minor": 0 862 | }, 863 | "text/plain": [ 864 | "HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))" 865 | ] 866 | }, 867 | "metadata": { 868 | "tags": [] 869 | }, 870 | "output_type": "display_data" 871 | }, 872 | { 873 | "name": "stdout", 874 | "output_type": "stream", 875 | "text": [ 876 | "\n", 877 | "Test Accuracy: 86.63%, Test loss: 0.3117\n" 878 | ] 879 | } 880 | ], 881 | "source": [ 882 | "# evaluate the model on test data\n", 883 | "model.load_state_dict(torch.load(model_path, map_location=device))\n", 884 | "\n", 885 | "print('Evaluating the model on test data ...')\n", 
886 | "test_acc, test_loss = eval_fn(model, test_iterator, criterion)\n", 887 | "print(f'Test Accuracy: {test_acc * 100:.2f}%, Test loss: {test_loss:.4f}')" 888 | ] 889 | }, 890 | { 891 | "cell_type": "markdown", 892 | "metadata": { 893 | "id": "fZ5zIFfEuLrP" 894 | }, 895 | "source": [ 896 | "## Inference\n" 897 | ] 898 | }, 899 | { 900 | "cell_type": "code", 901 | "execution_count": 59, 902 | "metadata": { 903 | "executionInfo": { 904 | "elapsed": 1540, 905 | "status": "ok", 906 | "timestamp": 1622651441453, 907 | "user": { 908 | "displayName": "Jay Gala", 909 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 910 | "userId": "11525158385229787114" 911 | }, 912 | "user_tz": -330 913 | }, 914 | "id": "kn5ENhvFwXc1" 915 | }, 916 | "outputs": [], 917 | "source": [ 918 | "nlp = spacy.load('en')\n", 919 | "\n", 920 | "def predict(model, text, min_len=5):\n", 921 | "    model.eval()\n", 922 | "\n", 923 | "    tokens = [token.text for token in nlp.tokenizer(text)]\n", 924 | "\n", 925 | "    if len(tokens) < min_len:\n", 926 | "        tokens += [review.pad_token] * (min_len - len(tokens))\n", 927 | "\n", 928 | "    token_ids = [review.vocab.stoi.get(token, review.vocab.stoi[review.unk_token]) for token in tokens]  # unknown tokens map to the <unk> index\n", 929 | "\n", 930 | "    token_ids = torch.tensor(token_ids, dtype=torch.long).to(device)  # (seq_len)\n", 931 | "    token_ids = token_ids.unsqueeze(0)  # (1, seq_len)\n", 932 | "\n", 933 | "    logits = model(token_ids)\n", 934 | "\n", 935 | "    pred_probs = F.softmax(logits, dim=1).squeeze(0)\n", 936 | "\n", 937 | "    return pred_probs" 938 | ] 939 | }, 940 | { 941 | "cell_type": "code", 942 | "execution_count": 67, 943 | "metadata": { 944 | "colab": { 945 | "base_uri": "https://localhost:8080/" 946 | }, 947 | "executionInfo": { 948 | "elapsed": 412, 949 | "status": "ok", 950 | "timestamp": 1622651694922, 951 | "user": { 952 | "displayName": "Jay Gala", 953 | "photoUrl": 
"https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 954 | "userId": "11525158385229787114" 955 | }, 956 | "user_tz": -330 957 | }, 958 | "id": "ypAw26oXk1sY", 959 | "outputId": "c38e85c8-23b0-4c95-e1ff-bdc20f1e7bcf" 960 | }, 961 | "outputs": [ 962 | { 963 | "name": "stdout", 964 | "output_type": "stream", 965 | "text": [ 966 | "The movie was great!\n", 967 | "positive: 73.38%, negative: 26.62%\n" 968 | ] 969 | } 970 | ], 971 | "source": [ 972 | "sample_text = 'The movie was great!'\n", 973 | "print(sample_text)\n", 974 | "\n", 975 | "pred_probs = predict(model, sample_text)\n", 976 | "print(f'positive: {pred_probs[0] * 100:.2f}%, negative: {(pred_probs[1]) * 100:.2f}%')" 977 | ] 978 | }, 979 | { 980 | "cell_type": "code", 981 | "execution_count": 68, 982 | "metadata": { 983 | "colab": { 984 | "base_uri": "https://localhost:8080/" 985 | }, 986 | "executionInfo": { 987 | "elapsed": 3, 988 | "status": "ok", 989 | "timestamp": 1622651695289, 990 | "user": { 991 | "displayName": "Jay Gala", 992 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 993 | "userId": "11525158385229787114" 994 | }, 995 | "user_tz": -330 996 | }, 997 | "id": "DHW1vcRvk7lR", 998 | "outputId": "7526541b-709b-4600-9520-2eda86e66fd9" 999 | }, 1000 | "outputs": [ 1001 | { 1002 | "name": "stdout", 1003 | "output_type": "stream", 1004 | "text": [ 1005 | "The movie was okay.\n", 1006 | "positive: 34.23%, negative: 65.77%\n" 1007 | ] 1008 | } 1009 | ], 1010 | "source": [ 1011 | "sample_text = 'The movie was okay.'\n", 1012 | "print(sample_text)\n", 1013 | "\n", 1014 | "pred_probs = predict(model, sample_text)\n", 1015 | "print(f'positive: {pred_probs[0] * 100:.2f}%, negative: {(pred_probs[1]) * 100:.2f}%')" 1016 | ] 1017 | }, 1018 | { 1019 | "cell_type": "code", 1020 | "execution_count": 69, 1021 | "metadata": { 1022 | "colab": { 1023 | "base_uri": "https://localhost:8080/" 1024 | }, 1025 | "executionInfo": { 
1026 | "elapsed": 17, 1027 | "status": "ok", 1028 | "timestamp": 1622651695975, 1029 | "user": { 1030 | "displayName": "Jay Gala", 1031 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ghr20IAxple1q2jggw00r147YTV5sjlqKrRANzYVA=s64", 1032 | "userId": "11525158385229787114" 1033 | }, 1034 | "user_tz": -330 1035 | }, 1036 | "id": "WdjKB5_8yWz-", 1037 | "outputId": "d6c58f33-ceb9-44ab-9448-c28b5295617e" 1038 | }, 1039 | "outputs": [ 1040 | { 1041 | "name": "stdout", 1042 | "output_type": "stream", 1043 | "text": [ 1044 | "The movie was terrible...\n", 1045 | "positive: 9.10%, negative: 90.90%\n" 1046 | ] 1047 | } 1048 | ], 1049 | "source": [ 1050 | "sample_text = 'The movie was terrible...'\n", 1051 | "print(sample_text)\n", 1052 | "\n", 1053 | "pred_probs = predict(model, sample_text)\n", 1054 | "print(f'positive: {pred_probs[0] * 100:.2f}%, negative: {(pred_probs[1]) * 100:.2f}%')" 1055 | ] 1056 | }, 1057 | { 1058 | "cell_type": "markdown", 1059 | "metadata": { 1060 | "id": "73M2zaJ4sjdX" 1061 | }, 1062 | "source": [ 1063 | "## References\n", 1064 | "\n", 1065 | "1. [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf)\n", 1066 | "2. [A Sensitivity Analysis of (and Practitioners’ Guide to) Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1510.03820.pdf)\n", 1067 | "3. [Official Theano Implementation by Yoon Kim](https://github.com/yoonkim/CNN_sentence)\n", 1068 | "4. 
[PyTorch Sentiment Analysis by Ben Trevett](https://github.com/bentrevett/pytorch-sentiment-analysis)\n" 1069 | ] 1070 | } 1071 | ], 1072 | "metadata": { 1073 | "accelerator": "GPU", 1074 | "colab": { 1075 | "authorship_tag": "ABX9TyO+P5BK60lx1rX+gFJHgQ+e", 1076 | "collapsed_sections": [], 1077 | "name": "Convolutional Neural Networks for Sentence Classification.ipynb", 1078 | "provenance": [], 1079 | "toc_visible": true 1080 | }, 1081 | "kernelspec": { 1082 | "display_name": "Python 3", 1083 | "language": "python", 1084 | "name": "python3" 1085 | }, 1086 | "language_info": { 1087 | "codemirror_mode": { 1088 | "name": "ipython", 1089 | "version": 3 1090 | }, 1091 | "file_extension": ".py", 1092 | "mimetype": "text/x-python", 1093 | "name": "python", 1094 | "nbconvert_exporter": "python", 1095 | "pygments_lexer": "ipython3", 1096 | "version": "3.8.5" 1097 | }, 1098 | "widgets": { 1099 | "application/vnd.jupyter.widget-state+json": { 1100 | "0e8d20a7c37f41ec8be9bedf0b24c5ad": { 1101 | "model_module": "@jupyter-widgets/controls", 1102 | "model_name": "HBoxModel", 1103 | "state": { 1104 | "_dom_classes": [], 1105 | "_model_module": "@jupyter-widgets/controls", 1106 | "_model_module_version": "1.5.0", 1107 | "_model_name": "HBoxModel", 1108 | "_view_count": null, 1109 | "_view_module": "@jupyter-widgets/controls", 1110 | "_view_module_version": "1.5.0", 1111 | "_view_name": "HBoxView", 1112 | "box_style": "", 1113 | "children": [ 1114 | "IPY_MODEL_5092dec0198d4ddca6ec280a8efe20fa", 1115 | "IPY_MODEL_97565fb53ca747baac3a3a795dbb60df" 1116 | ], 1117 | "layout": "IPY_MODEL_77cffdfd79db46d098ff3e85694af708" 1118 | } 1119 | }, 1120 | "5092dec0198d4ddca6ec280a8efe20fa": { 1121 | "model_module": "@jupyter-widgets/controls", 1122 | "model_name": "FloatProgressModel", 1123 | "state": { 1124 | "_dom_classes": [], 1125 | "_model_module": "@jupyter-widgets/controls", 1126 | "_model_module_version": "1.5.0", 1127 | "_model_name": "FloatProgressModel", 1128 | "_view_count": null, 1129 
| "_view_module": "@jupyter-widgets/controls", 1130 | "_view_module_version": "1.5.0", 1131 | "_view_name": "ProgressView", 1132 | "bar_style": "success", 1133 | "description": "100%", 1134 | "description_tooltip": null, 1135 | "layout": "IPY_MODEL_57835188d57b40b3b7ab7b32167d5f20", 1136 | "max": 391, 1137 | "min": 0, 1138 | "orientation": "horizontal", 1139 | "style": "IPY_MODEL_76a92c955bba41f4862918d88c888376", 1140 | "value": 391 1141 | } 1142 | }, 1143 | "57835188d57b40b3b7ab7b32167d5f20": { 1144 | "model_module": "@jupyter-widgets/base", 1145 | "model_name": "LayoutModel", 1146 | "state": { 1147 | "_model_module": "@jupyter-widgets/base", 1148 | "_model_module_version": "1.2.0", 1149 | "_model_name": "LayoutModel", 1150 | "_view_count": null, 1151 | "_view_module": "@jupyter-widgets/base", 1152 | "_view_module_version": "1.2.0", 1153 | "_view_name": "LayoutView", 1154 | "align_content": null, 1155 | "align_items": null, 1156 | "align_self": null, 1157 | "border": null, 1158 | "bottom": null, 1159 | "display": null, 1160 | "flex": null, 1161 | "flex_flow": null, 1162 | "grid_area": null, 1163 | "grid_auto_columns": null, 1164 | "grid_auto_flow": null, 1165 | "grid_auto_rows": null, 1166 | "grid_column": null, 1167 | "grid_gap": null, 1168 | "grid_row": null, 1169 | "grid_template_areas": null, 1170 | "grid_template_columns": null, 1171 | "grid_template_rows": null, 1172 | "height": null, 1173 | "justify_content": null, 1174 | "justify_items": null, 1175 | "left": null, 1176 | "margin": null, 1177 | "max_height": null, 1178 | "max_width": null, 1179 | "min_height": null, 1180 | "min_width": null, 1181 | "object_fit": null, 1182 | "object_position": null, 1183 | "order": null, 1184 | "overflow": null, 1185 | "overflow_x": null, 1186 | "overflow_y": null, 1187 | "padding": null, 1188 | "right": null, 1189 | "top": null, 1190 | "visibility": null, 1191 | "width": null 1192 | } 1193 | }, 1194 | "76a92c955bba41f4862918d88c888376": { 1195 | "model_module": 
"@jupyter-widgets/controls", 1196 | "model_name": "ProgressStyleModel", 1197 | "state": { 1198 | "_model_module": "@jupyter-widgets/controls", 1199 | "_model_module_version": "1.5.0", 1200 | "_model_name": "ProgressStyleModel", 1201 | "_view_count": null, 1202 | "_view_module": "@jupyter-widgets/base", 1203 | "_view_module_version": "1.2.0", 1204 | "_view_name": "StyleView", 1205 | "bar_color": null, 1206 | "description_width": "initial" 1207 | } 1208 | }, 1209 | "77cffdfd79db46d098ff3e85694af708": { 1210 | "model_module": "@jupyter-widgets/base", 1211 | "model_name": "LayoutModel", 1212 | "state": { 1213 | "_model_module": "@jupyter-widgets/base", 1214 | "_model_module_version": "1.2.0", 1215 | "_model_name": "LayoutModel", 1216 | "_view_count": null, 1217 | "_view_module": "@jupyter-widgets/base", 1218 | "_view_module_version": "1.2.0", 1219 | "_view_name": "LayoutView", 1220 | "align_content": null, 1221 | "align_items": null, 1222 | "align_self": null, 1223 | "border": null, 1224 | "bottom": null, 1225 | "display": null, 1226 | "flex": null, 1227 | "flex_flow": null, 1228 | "grid_area": null, 1229 | "grid_auto_columns": null, 1230 | "grid_auto_flow": null, 1231 | "grid_auto_rows": null, 1232 | "grid_column": null, 1233 | "grid_gap": null, 1234 | "grid_row": null, 1235 | "grid_template_areas": null, 1236 | "grid_template_columns": null, 1237 | "grid_template_rows": null, 1238 | "height": null, 1239 | "justify_content": null, 1240 | "justify_items": null, 1241 | "left": null, 1242 | "margin": null, 1243 | "max_height": null, 1244 | "max_width": null, 1245 | "min_height": null, 1246 | "min_width": null, 1247 | "object_fit": null, 1248 | "object_position": null, 1249 | "order": null, 1250 | "overflow": null, 1251 | "overflow_x": null, 1252 | "overflow_y": null, 1253 | "padding": null, 1254 | "right": null, 1255 | "top": null, 1256 | "visibility": null, 1257 | "width": null 1258 | } 1259 | }, 1260 | "7ddcff75499442dc8e472c6121eed9c2": { 1261 | "model_module": 
"@jupyter-widgets/controls", 1262 | "model_name": "DescriptionStyleModel", 1263 | "state": { 1264 | "_model_module": "@jupyter-widgets/controls", 1265 | "_model_module_version": "1.5.0", 1266 | "_model_name": "DescriptionStyleModel", 1267 | "_view_count": null, 1268 | "_view_module": "@jupyter-widgets/base", 1269 | "_view_module_version": "1.2.0", 1270 | "_view_name": "StyleView", 1271 | "description_width": "" 1272 | } 1273 | }, 1274 | "97565fb53ca747baac3a3a795dbb60df": { 1275 | "model_module": "@jupyter-widgets/controls", 1276 | "model_name": "HTMLModel", 1277 | "state": { 1278 | "_dom_classes": [], 1279 | "_model_module": "@jupyter-widgets/controls", 1280 | "_model_module_version": "1.5.0", 1281 | "_model_name": "HTMLModel", 1282 | "_view_count": null, 1283 | "_view_module": "@jupyter-widgets/controls", 1284 | "_view_module_version": "1.5.0", 1285 | "_view_name": "HTMLView", 1286 | "description": "", 1287 | "description_tooltip": null, 1288 | "layout": "IPY_MODEL_9ed81c02af6047dea211c5483ba358bc", 1289 | "placeholder": "​", 1290 | "style": "IPY_MODEL_7ddcff75499442dc8e472c6121eed9c2", 1291 | "value": " 391/391 [00:25<00:00, 15.09it/s]" 1292 | } 1293 | }, 1294 | "9ed81c02af6047dea211c5483ba358bc": { 1295 | "model_module": "@jupyter-widgets/base", 1296 | "model_name": "LayoutModel", 1297 | "state": { 1298 | "_model_module": "@jupyter-widgets/base", 1299 | "_model_module_version": "1.2.0", 1300 | "_model_name": "LayoutModel", 1301 | "_view_count": null, 1302 | "_view_module": "@jupyter-widgets/base", 1303 | "_view_module_version": "1.2.0", 1304 | "_view_name": "LayoutView", 1305 | "align_content": null, 1306 | "align_items": null, 1307 | "align_self": null, 1308 | "border": null, 1309 | "bottom": null, 1310 | "display": null, 1311 | "flex": null, 1312 | "flex_flow": null, 1313 | "grid_area": null, 1314 | "grid_auto_columns": null, 1315 | "grid_auto_flow": null, 1316 | "grid_auto_rows": null, 1317 | "grid_column": null, 1318 | "grid_gap": null, 1319 | "grid_row": 
null, 1320 | "grid_template_areas": null, 1321 | "grid_template_columns": null, 1322 | "grid_template_rows": null, 1323 | "height": null, 1324 | "justify_content": null, 1325 | "justify_items": null, 1326 | "left": null, 1327 | "margin": null, 1328 | "max_height": null, 1329 | "max_width": null, 1330 | "min_height": null, 1331 | "min_width": null, 1332 | "object_fit": null, 1333 | "object_position": null, 1334 | "order": null, 1335 | "overflow": null, 1336 | "overflow_x": null, 1337 | "overflow_y": null, 1338 | "padding": null, 1339 | "right": null, 1340 | "top": null, 1341 | "visibility": null, 1342 | "width": null 1343 | } 1344 | } 1345 | } 1346 | } 1347 | }, 1348 | "nbformat": 4, 1349 | "nbformat_minor": 4 1350 | } --------------------------------------------------------------------------------