├── .gitignore
├── LICENSE
├── README.md
├── chapter01
│   ├── Basic_Machine_Learning_in_Python_Introduction_to_Sklearn,_Keras_and_Pytorch.ipynb
│   ├── Dockerfile
│   ├── Modeling_and_Predicting_in_Keras.ipynb
│   ├── README.md
│   ├── environment.yml
│   ├── perceptron.dot
│   ├── perceptron.png
│   └── run_notebook.sh
├── chapter02
│   ├── Battling Algorithmic Bias.ipynb
│   ├── Forecasting CO2 Time Series.ipynb
│   ├── Live Decisioning Customer Values.ipynb
│   ├── Predicting House Prices in Pytorch.ipynb
│   └── Transforming Data in Scikit-Learn.ipynb
├── chapter03
│   ├── Clustering market segments.ipynb
│   ├── Discovering anomalities.ipynb
│   ├── Recommending_products.ipynb
│   ├── Representing for similarity search.ipynb
│   └── Spotting fraudster communities.ipynb
├── chapter04
│   ├── Diagnosing a disease.ipynb
│   ├── Estimating customer lifetime-value.ipynb
│   ├── Predicting stock prices with confidence.ipynb
│   └── stopping_credit_default.ipynb
├── chapter05
│   ├── Finding the shortest path.ipynb
│   ├── Making decisions based on knowledge.ipynb
│   ├── Simulating the spread of a disease.ipynb
│   ├── Solving the n-queens problem.ipynb
│   ├── Writing a chess engine with Monte Carlo tree search.ipynb
│   ├── pso_it0.png
│   ├── pso_it1322.png
│   ├── pso_it1323-1.png
│   ├── pso_it3-1.png
│   ├── pso_it32-1.png
│   ├── skipgram.dot
│   ├── skipgram.png
│   └── solving-n-queens.md
├── chapter06
│   ├── Controling a cartpole.ipynb
│   ├── Optimizing a website.ipynb
│   └── Playing blackjack.ipynb
├── chapter07
│   ├── Encoding_images_and_style.ipynb
│   ├── Generating_images.ipynb
│   └── Recognizing clothing items.ipynb
├── chapter08
│   └── Localizing_objects.ipynb
├── chapter09
│   ├── Generating_melodies.ipynb
│   ├── Synthesizing_speech.ipynb
│   └── recognizing_voice_commands.ipynb
├── chapter10
│   ├── Chatting_to_users.ipynb
│   ├── Classifying_newsgroups.ipynb
│   ├── Translating_from_English_to_German.ipynb
│   └── Writing_a_popular_novel.ipynb
└── chapter11
    ├── Securing Models Against Attack.ipynb
    ├── Serving a model for live-decisioning.ipynb
    └── visualizing_model_results.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Packt
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Artificial Intelligence with Python Cookbook
5 |
6 |
7 |
8 | This is the code repository for [Artificial Intelligence with Python Cookbook](https://www.packtpub.com/product/artificial-intelligence-with-python-cookbook/9781789133967), published by Packt.
9 |
10 | **Practical recipes for next-generation deep learning and neural networks using TensorFlow and PyTorch**
11 |
12 | ## What is this book about?
13 | With artificial intelligence (AI) systems, we can develop goal-driven agents to automate problem-solving. This involves predicting and classifying the available data and training agents to execute tasks successfully. This book will help you to solve complex AI problems using practical recipes.
14 |
15 | This book covers the following exciting features:
16 | * Implement data preprocessing steps and optimize model hyperparameters
17 | * Work with large amounts of data using distributed and parallel computing techniques
19 | * Get to grips with representation learning from images using InfoGAN
19 | * Delve into deep probabilistic modeling with a Bayesian network
20 | * Create your own artwork using adversarial neural networks
21 |
22 | If you feel this book is for you, get your [copy](https://www.amazon.com/dp/1789133963) today!
23 |
24 |
26 |
27 |
28 | ## Instructions and Navigations
29 | All of the code is organized into one folder per chapter, for example, `chapter02`.
30 |
31 | The code will look like the following:
32 | ```
33 | from sklearn.datasets import fetch_openml
34 | data = fetch_openml(data_id=42165, as_frame=True)
35 | ```
36 |
37 | **Following is what you need for this book:**
38 | This AI book is for Python developers, data scientists, machine learning developers, and deep learning practitioners who want to learn how to build artificial intelligence solutions with easy-to-follow recipes. If you are looking for state-of-the-art solutions to perform different machine learning tasks in various use cases, this book is for you. Basic working knowledge of the Python programming language and machine learning concepts will help you work with the code.
39 |
40 | With the following software and hardware list, you can run all the code files in the book (Chapters 1-11).
41 |
42 | ### Software and Hardware List
43 |
44 | | Chapter | Software required | OS required |
45 | | -------- | ------------------------------------| -----------------------------------|
46 | | 1 | Python 3.6 or later | Windows, Mac OS X, and Linux (Any) |
47 | | 2 |TensorFlow 2.0 or later | Windows, Mac OS X, and Linux (Any) |
48 | | 3 | PyTorch 1.6 or later | Windows, Mac OS X, and Linux (Any) |
49 | | 4 |Pandas 1.0 or later | Windows, Mac OS X, and Linux (Any) |
50 | | 5 | Scikit-learn 0.22.0 or later | Windows, Mac OS X, and Linux (Any) |
51 |
52 |
53 | We also provide a PDF file that has color images of the screenshots/diagrams used in this book. [Click here to download it](https://static.packt-cdn.com/downloads/9781789133967_ColorImages.pdf).
54 |
55 | ### Related products
56 | * Hands-On Artificial Intelligence for Banking [[Packt]](https://www.packtpub.com/product/hands-on-artificial-intelligence-for-banking/9781788830782) [[Amazon]](https://www.amazon.com/dp/1788830784)
57 |
58 | * Artificial Intelligence with Python - Second Edition [[Packt]](https://www.packtpub.com/product/artificial-intelligence-with-python-second-edition/9781839219535) [[Amazon]](https://www.amazon.com/dp/183921953X)
59 |
60 | ## Get to Know the Author
61 | **Ben Auffarth**
62 | is a full-stack data scientist with more than 15 years of work experience. With a background and Ph.D. in computational and cognitive neuroscience, he has designed and conducted wet lab experiments on cell cultures, analyzed experiments with terabytes of data, run brain models on IBM supercomputers with up to 64k cores, built production systems processing hundreds of thousands of transactions per day, and trained neural networks on millions of text documents. He resides in West London with his family, where you might find him in a playground with his young son. He co-founded and is the former president of Data Science Speakers, London.
63 |
64 | ### Suggestions and Feedback
65 | [Click here](https://docs.google.com/forms/d/e/1FAIpQLSdy7dATC6QmEL81FIUuymZ0Wy9vH1jHkvpY57OiMeKGqib_Ow/viewform) if you have any feedback or suggestions.
66 | ### Download a free PDF
67 |
68 | If you have already purchased a print or Kindle version of this book, you can get a DRM-free PDF version at no cost. Simply click on the link to claim your free PDF.
69 |
70 | https://packt.link/free-ebook/9781789133967
--------------------------------------------------------------------------------
/chapter01/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM continuumio/miniconda3
2 | ADD ./run_notebook.sh ./environment.yml /usr/local/bin/
3 | RUN chmod +x /usr/local/bin/run_notebook.sh \
4 |     && chmod -R 777 /usr/local/*
5 | RUN apt-get update -qq && apt-get install -y libgl1-mesa-glx libegl1-mesa libxrandr2 libxss1 libxcursor1 libxcomposite1 libasound2 libxi6 libxtst6 build-essential vim curl wget libhdf5-dev libhdf5-serial-dev cython3 python-h5py
6 | RUN useradd -m docker -s /bin/bash -p '*' && chown -R 1000:1000 /opt/conda
7 | USER docker
8 | RUN conda env create -f /usr/local/bin/environment.yml && echo "source activate env" >> ~/.bashrc
9 | ENV PATH /opt/conda/envs/env/bin:$PATH
10 | USER root
11 | CMD /bin/bash /usr/local/bin/run_notebook.sh
--------------------------------------------------------------------------------
/chapter01/README.md:
--------------------------------------------------------------------------------
1 | # Hosting a Jupyter environment with Docker
2 | It is out of the scope of this book to discuss all the details of Docker, but briefly, Docker is a tool that lets you run software in isolated, lightweight environments called containers, which are built from images.
3 |
4 | The setup consists of three files:
5 | - Dockerfile - a text file that lists the instructions for building the container image
6 | - environment.yml - a file listing the Python dependencies to be installed by conda
7 | - run_notebook.sh - a script that starts the Jupyter notebook server inside the container
8 |
9 | Add the following three files to an empty directory:
10 | ## Dockerfile
11 |
12 | ```dockerfile
13 | FROM continuumio/miniconda3
14 | ADD ./run_notebook.sh ./environment.yml /usr/local/bin/
15 | RUN chmod +x /usr/local/bin/run_notebook.sh \
16 |     && chmod -R 777 /usr/local/*
17 | RUN apt-get update -qq && apt-get install -y libgl1-mesa-glx libegl1-mesa libxrandr2 libxss1 libxcursor1 libxcomposite1 libasound2 libxi6 libxtst6 build-essential vim curl wget libhdf5-dev libhdf5-serial-dev cython3 python-h5py
18 | RUN useradd -m docker -s /bin/bash -p '*' && chown -R 1000:1000 /opt/conda
19 | USER docker
20 | RUN conda env create -f /usr/local/bin/environment.yml && echo "source activate env" >> ~/.bashrc
21 | ENV PATH /opt/conda/envs/env/bin:$PATH
22 | USER root
23 | CMD /bin/bash /usr/local/bin/run_notebook.sh
24 | ```
25 |
26 | ## environment.yml
27 | ```yaml
28 | name: env
29 | channels:
30 |   - defaults
31 | dependencies:
32 |   - notebook<7
33 |   - pip=19.3.1=py37_0
34 |   - protobuf=3.9.2=py37he6710b0_0
35 |   - python=3.7.0=h6e4f718_3
36 |   - setuptools=41.6.0=py37_0
37 |   - sqlite=3.30.1=h7b6447c_0
38 |   - tensorboard=2.0.0
39 |   - tensorflow=2.0.0
40 |   - tensorflow-base=2.0.0
41 |   - tensorflow-estimator=2.0.0
42 |   - pip:
43 |     - torch==1.3.1
44 |     - numpy==1.18.0
45 |     - pandas==0.25.3
46 |     - scikit-learn==0.21.3
47 |     - scipy==1.2.0
48 | prefix: /opt/conda/envs/env
49 | ```
50 |
51 | ## run_notebook.sh
52 | ```bash
53 | #!/bin/bash
54 | # run_notebook.sh
55 | ## Don't attempt to run if we are not root
56 | ## EUID stands for Effective User ID
57 | if [ "$EUID" -ne 0 ]; then
58 |   echo "Please run as root"
59 |   exit 1
60 | fi
61 |
62 | ## Set defaults for environment variables in case they are undefined
63 | USER=${USER:=jupyter}
64 | PASS=${PASS:=jupyter}
65 | USERID=${USERID:=1000}
66 | USERGID=${USERGID:=1000}
67 | CONFIG=".jupyter/jupyter_notebook_config.py"
68 |
69 | if [ "$USERID" -ne 0 ]; then
70 |   # -o allows the UID/GID to coincide with an existing account, e.g. the image's docker user
71 |   echo "creating new $USER with UID $USERID"
72 |   groupadd -o -g $USERGID $USER
73 |   useradd -o -m -u $USERID -g $USERGID $USER
74 | fi
75 | cd /home/$USER
76 | mkdir -p .jupyter
77 | /bin/cat <<EOF >$CONFIG
78 | from notebook.auth import passwd
79 | c = get_config()
80 | passw = passwd('$PASS')
81 | c.NotebookApp.password = passw
82 | c.IPKernelApp.pylab = 'inline'
83 | c.NotebookManager.save_script = True
84 | c.NotebookApp.open_browser = False
85 | c.NotebookApp.port = 9999
86 | c.NotebookApp.ip = '0.0.0.0'
87 | # avoid restart on slow connections:
88 | c.NotebookApp.tornado_settings = {'kernel_info_timeout': 60}
89 | EOF
90 | chown -R $USER:$USER .jupyter
91 | su $USER -c "jupyter notebook"
92 | ```
93 | Then build the image and start the container (add `--runtime=nvidia` to `docker run` only if you rely on the NVIDIA Docker runtime):
94 | ```bash
95 | docker build -t jupyter .
96 | PORT=5050
97 | docker run -d \
98 |   --name "jupyter_${USER}_${PORT}" \
99 |   -p $PORT:9999 \
100 |   -e USER=$USER \
101 |   -e USERGID=$(id -g) -e USERID=$(id -u) \
102 |   -e PASS=$PASS \
103 |   jupyter /usr/local/bin/run_notebook.sh
104 | ```
105 |
106 | You can add mount parameters to the `docker run` command with the `-v` option, which is useful if you want the container to share directories with the host machine. Otherwise, you can copy files with the `docker cp` command.
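
If you would rather launch the container from Python, for example from a setup script, one option is to build the `docker run` arguments as a list and hand them to `subprocess.run`; this avoids the shell quoting and line-continuation pitfalls of the command above. This is a minimal sketch, not part of the book's code: the helper `docker_run_args`, the password and the mount paths are hypothetical, and the flags simply mirror the `docker run` command above, plus one `-v host_dir:container_dir` option per mount.

```python
import shlex


def docker_run_args(port, user, uid, gid, password,
                    volumes=None, gpu=False, image="jupyter"):
    """Return the argument list for `docker run` (hypothetical helper).

    volumes maps host directories to container directories; each pair
    becomes a `-v host:container` bind mount.
    """
    args = ["docker", "run", "-d"]
    if gpu:
        # only needed if the host uses the NVIDIA Docker runtime
        args.append("--runtime=nvidia")
    args += [
        "--name", f"jupyter_{user}_{port}",
        "-p", f"{port}:9999",  # 9999 is the port set in run_notebook.sh
        "-e", f"USER={user}",
        "-e", f"USERGID={gid}",
        "-e", f"USERID={uid}",
        "-e", f"PASS={password}",
    ]
    for host_dir, container_dir in (volumes or {}).items():
        args += ["-v", f"{host_dir}:{container_dir}"]
    # options end here: image name first, then the command run inside it
    args += [image, "/usr/local/bin/run_notebook.sh"]
    return args


# hypothetical values for illustration
args = docker_run_args(
    port=5050, user="jupyter", uid=1000, gid=1000, password="change-me",
    volumes={"/data/notebooks": "/home/jupyter/work"},
)
print(shlex.join(args))  # inspect the command before running it
# To launch it: subprocess.run(args, check=True)
```

Printing with `shlex.join` (Python 3.8+) gives a copy-pasteable shell command for inspection; passing the list itself to `subprocess.run` skips the shell entirely, so no escaping is needed.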
--------------------------------------------------------------------------------
/chapter01/environment.yml:
--------------------------------------------------------------------------------
1 | name: env
2 | channels:
3 |   - defaults
4 | dependencies:
5 |   - notebook<7
6 |   - pip=19.3.1=py37_0
7 |   - protobuf=3.9.2=py37he6710b0_0
8 |   - python=3.7.0=h6e4f718_3
9 |   - setuptools=41.6.0=py37_0
10 |   - sqlite=3.30.1=h7b6447c_0
11 |   - tensorboard=2.0.0
12 |   - tensorflow=2.0.0
13 |   - tensorflow-base=2.0.0
14 |   - tensorflow-estimator=2.0.0
15 |   - pip:
16 |     - torch==1.3.1
17 |     - numpy==1.18.0
18 |     - pandas==0.25.3
19 |     - scikit-learn==0.21.3
20 |     - scipy==1.2.0
21 | prefix: /opt/conda/envs/env
--------------------------------------------------------------------------------
/chapter01/perceptron.dot:
--------------------------------------------------------------------------------
1 | digraph G {
2 | # dot -Tpng perceptron.dot > perceptron.png
3 | rankdir=LR;
4 | in1 [label="x_0"];
5 | in2 [label="x_1"];
6 | in3 [label="x_2"];
7 | in4 [label="x_3"];
8 | in1 -> out [label="w_0"]
9 | in2 -> out [label="w_1"]
10 | in3 -> out [label="w_2"]
11 | in4 -> out [label="w_3"]
12 | }
--------------------------------------------------------------------------------
/chapter01/perceptron.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Artificial-Intelligence-with-Python-Cookbook/31f15f25c0bdc3286cbac6da75573ce9f069d25a/chapter01/perceptron.png
--------------------------------------------------------------------------------
/chapter01/run_notebook.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # run_notebook.sh
3 | ## Don't attempt to run if we are not root
4 | ## EUID stands for Effective User ID
5 | if [ "$EUID" -ne 0 ]; then
6 |   echo "Please run as root"
7 |   exit 1
8 | fi
9 |
10 | ## Set defaults for environment variables in case they are undefined
11 | USER=${USER:=jupyter}
12 | PASS=${PASS:=jupyter}
13 | USERID=${USERID:=1000}
14 | USERGID=${USERGID:=1000}
15 | CONFIG=".jupyter/jupyter_notebook_config.py"
16 |
17 | if [ "$USERID" -ne 0 ]; then
18 |   # -o allows the UID/GID to coincide with an existing account, e.g. the image's docker user
19 |   echo "creating new $USER with UID $USERID"
20 |   groupadd -o -g $USERGID $USER
21 |   useradd -o -m -u $USERID -g $USERGID $USER
22 | fi
23 | cd /home/$USER
24 | mkdir -p .jupyter
25 | /bin/cat <<EOF >$CONFIG
26 | from notebook.auth import passwd
27 | c = get_config()
28 | passw = passwd('$PASS')
29 | c.NotebookApp.password = passw
30 | c.IPKernelApp.pylab = 'inline'
31 | c.NotebookManager.save_script = True
32 | c.NotebookApp.open_browser = False
33 | c.NotebookApp.port = 9999
34 | c.NotebookApp.ip = '0.0.0.0'
35 | # avoid restart on slow connections:
36 | c.NotebookApp.tornado_settings = {'kernel_info_timeout': 60}
37 | EOF
38 | chown -R $USER:$USER .jupyter
39 | su $USER -c "jupyter notebook"
--------------------------------------------------------------------------------
/chapter02/Battling Algorithmic Bias.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 16,
6 | "metadata": {
7 | "colab": {
8 | "base_uri": "https://localhost:8080/",
9 | "height": 224
10 | },
11 | "colab_type": "code",
12 | "id": "UF3yuopkSW0A",
13 | "outputId": "9783c242-3e97-4ef0-a704-146b966894a6"
14 | },
15 | "outputs": [
16 | {
17 | "name": "stdout",
18 | "output_type": "stream",
19 | "text": [
20 | "--2020-08-06 10:20:39-- https://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv\n",
21 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n",
22 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n",
23 | "HTTP request sent, awaiting response... 200 OK\n",
24 | "Length: 2546489 (2.4M) [text/plain]\n",
25 | "Saving to: ‘compas-scores-two-years.csv’\n",
26 | "\n",
27 | "\r",
28 | " compas-sc 0%[ ] 0 --.-KB/s \r",
29 | "compas-scores-two-y 100%[===================>] 2.43M 14.4MB/s in 0.2s \n",
30 | "\n",
31 | "2020-08-06 10:20:40 (14.4 MB/s) - ‘compas-scores-two-years.csv’ saved [2546489/2546489]\n",
32 | "\n"
33 | ]
34 | }
35 | ],
36 | "source": [
37 | "!wget https://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv\n",
38 | "import pandas as pd\n",
39 | "date_cols = [\n",
40 | " 'compas_screening_date', 'c_offense_date',\n",
41 | " 'c_arrest_date', 'r_offense_date', \n",
42 | " 'vr_offense_date', 'screening_date',\n",
43 | " 'v_screening_date', 'c_jail_in',\n",
44 | " 'c_jail_out', 'dob', 'in_custody', \n",
45 | " 'out_custody'\n",
46 | "]\n",
47 | "data = pd.read_csv(\n",
48 | " 'compas-scores-two-years.csv',\n",
49 | " parse_dates=date_cols\n",
50 | ")\n"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 17,
56 | "metadata": {
57 | "colab": {},
58 | "colab_type": "code",
59 | "id": "sLLJgrpfSZtf"
60 | },
61 | "outputs": [],
62 | "source": [
63 | "import datetime\n",
64 | "indexes = data.compas_screening_date <= pd.Timestamp(datetime.date(2014, 4, 1))\n",
65 | "assert indexes.sum() == 6216\n",
66 | "data = data[indexes]\n"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 18,
72 | "metadata": {
73 | "colab": {
74 | "base_uri": "https://localhost:8080/",
75 | "height": 224
76 | },
77 | "colab_type": "code",
78 | "id": "1hGtkoiJSbTN",
79 | "outputId": "992c1197-5a3c-40d8-ed8a-56c37e01039e"
80 | },
81 | "outputs": [
82 | {
83 | "name": "stdout",
84 | "output_type": "stream",
85 | "text": [
86 | "Requirement already satisfied: category-encoders in /usr/local/lib/python3.6/dist-packages (2.2.2)\n",
87 | "Requirement already satisfied: scikit-learn>=0.20.0 in /usr/local/lib/python3.6/dist-packages (from category-encoders) (0.22.2.post1)\n",
88 | "Requirement already satisfied: statsmodels>=0.9.0 in /usr/local/lib/python3.6/dist-packages (from category-encoders) (0.10.2)\n",
89 | "Requirement already satisfied: scipy>=1.0.0 in /usr/local/lib/python3.6/dist-packages (from category-encoders) (1.4.1)\n",
90 | "Requirement already satisfied: numpy>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from category-encoders) (1.18.5)\n",
91 | "Requirement already satisfied: patsy>=0.5.1 in /usr/local/lib/python3.6/dist-packages (from category-encoders) (0.5.1)\n",
92 | "Requirement already satisfied: pandas>=0.21.1 in /usr/local/lib/python3.6/dist-packages (from category-encoders) (1.0.5)\n",
93 | "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn>=0.20.0->category-encoders) (0.16.0)\n",
94 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from patsy>=0.5.1->category-encoders) (1.15.0)\n",
95 | "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.21.1->category-encoders) (2.8.1)\n",
96 | "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.21.1->category-encoders) (2018.9)\n"
97 | ]
98 | }
99 | ],
100 | "source": [
101 | "!pip install category-encoders"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 13,
107 | "metadata": {
108 | "colab": {},
109 | "colab_type": "code",
110 | "id": "nOWV7G65Gjqj"
111 | },
112 | "outputs": [],
113 | "source": [
114 | "def confusion_metrics(actual, scores, threshold):\n",
115 | " y_predicted = scores.apply(\n",
116 | " lambda x: x >= threshold\n",
117 | " ).values\n",
118 | " y_true = actual.values\n",
119 | " TP = (\n",
120 | " (y_true==y_predicted) & \n",
121 | " (y_predicted==1)\n",
122 | " ).astype(int)\n",
123 | " FP = (\n",
124 | " (y_true!=y_predicted) &\n",
125 | " (y_predicted==1)\n",
126 | " ).astype(int)\n",
127 | " TN = (\n",
128 | " (y_true==y_predicted) &\n",
129 | " (y_predicted==0)\n",
130 | " ).astype(int)\n",
131 | " FN = (\n",
132 | " (y_true!=y_predicted) &\n",
133 | " (y_predicted==0)\n",
134 | " ).astype(int)\n",
135 | " return TP, FP, TN, FN\n"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 14,
141 | "metadata": {
142 | "colab": {},
143 | "colab_type": "code",
144 | "id": "K3qiEKETSMOE"
145 | },
146 | "outputs": [],
147 | "source": [
148 | "def calculate_impacts(data, sensitive_column='race', recid_col='is_recid', score_col='decile_score.1', threshold=5.0):\n",
149 | " if sensitive_column == 'race':\n",
150 | " norm_group = 'Caucasian'\n",
151 | " elif sensitive_column == 'sex':\n",
152 | " norm_group = 'Male'\n",
153 | " else:\n",
154 | " raise ValueError('sensitive column not implemented')\n",
155 | " TP, FP, TN, FN = confusion_metrics(\n",
156 | " actual=data[recid_col],\n",
157 | " scores=data[score_col],\n",
158 | " threshold=threshold\n",
159 | " )\n",
160 | " impact = pd.DataFrame(\n",
161 | " data=np.column_stack([\n",
162 | " FP, TP, FN, TN,\n",
163 | " data[sensitive_column].values, \n",
164 | " data[recid_col].values,\n",
165 | " data[score_col].values / 10.0\n",
166 | " ]),\n",
167 | " columns=['FP', 'TP', 'FN', 'TN', 'sensitive', 'reoffend', 'score']\n",
168 | " ).groupby(by='sensitive').agg({\n",
169 | " 'reoffend': 'sum', 'score': 'sum',\n",
170 | " 'sensitive': 'count', \n",
171 | " 'FP': 'sum', 'TP': 'sum', 'FN': 'sum', 'TN': 'sum'\n",
172 | " }).rename(\n",
173 | " columns={'sensitive': 'N'}\n",
174 | " )\n",
175 | " impact['FPR'] = impact['FP'] / (impact['FP'] + impact['TN'])\n",
176 | " impact['FNR'] = impact['FN'] / (impact['FN'] + impact['TP'])\n",
177 | " impact['reoffend'] = impact['reoffend'] / impact['N']\n",
178 | " impact['score'] = impact['score'] / impact['N']\n",
179 | " impact['DFP'] = impact['FPR'] / impact.loc[norm_group, 'FPR']\n",
180 | " impact['DFN'] = impact['FNR'] / impact.loc[norm_group, 'FNR']\n",
181 | " return impact.drop(columns=['FP', 'TP', 'FN', 'TN'])\n"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": 19,
187 | "metadata": {
188 | "colab": {},
189 | "colab_type": "code",
190 | "id": "GTE2xbhPSOw9"
191 | },
192 | "outputs": [],
193 | "source": [
194 | "from sklearn.feature_extraction.text import CountVectorizer\n",
195 | "from category_encoders.one_hot import OneHotEncoder\n",
196 | "from sklearn.model_selection import train_test_split\n",
197 | "from sklearn.preprocessing import StandardScaler\n",
198 | "\n",
199 | "charge_desc = data['c_charge_desc'].apply(lambda x: x if isinstance(x, str) else '')\n",
200 | "count_vectorizer = CountVectorizer(\n",
201 | " max_df=0.85, stop_words='english',\n",
202 | " max_features=100, decode_error='ignore'\n",
203 | ")\n",
204 | "charge_desc_features = count_vectorizer.fit_transform(charge_desc)\n",
205 | "\n",
206 | "one_hot_encoder = OneHotEncoder()\n",
207 | "charge_degree_features = one_hot_encoder.fit_transform(\n",
208 | " data['c_charge_degree']\n",
209 | ")\n",
210 | "\n",
211 | "data['race_black'] = data['race'].apply(lambda x: x == 'African-American').astype(int)\n",
212 | "stratification = data['race_black'] + (data['is_recid']).astype(int) * 2\n"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": 20,
218 | "metadata": {
219 | "colab": {},
220 | "colab_type": "code",
221 | "id": "P7vdYFUHSRzi"
222 | },
223 | "outputs": [],
224 | "source": [
225 | "import numpy as np\n",
226 | "y = data['is_recid']\n",
227 | "X = pd.DataFrame(\n",
228 | " data=np.column_stack(\n",
229 | " [data[['juv_fel_count', 'juv_misd_count',\n",
230 | " 'juv_other_count', 'priors_count', 'days_b_screening_arrest']], \n",
231 | " charge_degree_features, \n",
232 | " charge_desc_features.todense()\n",
233 | " ]\n",
234 | " ),\n",
235 | " columns=['juv_fel_count', 'juv_misd_count', 'juv_other_count', 'priors_count', 'days_b_screening_arrest'] \\\n",
236 | " + one_hot_encoder.get_feature_names() \\\n",
237 | " + count_vectorizer.get_feature_names(),\n",
238 | " index=data.index\n",
239 | ")\n",
240 | "X['jailed_days'] = (data['c_jail_out'] - data['c_jail_in']).apply(lambda x: abs(x.days))\n",
241 | "X['waiting_jail_days'] = (data['c_jail_in'] - data['c_offense_date']).apply(lambda x: abs(x.days))\n",
242 | "X['waiting_arrest_days'] = (data['c_arrest_date'] - data['c_offense_date']).apply(lambda x: abs(x.days))\n",
243 | "X.fillna(0, inplace=True)\n",
244 | "\n",
245 | "columns = list(X.columns)\n",
246 | "X_train, X_test, y_train, y_test = train_test_split(\n",
247 | " X, y, test_size=0.33,\n",
248 | " random_state=42,\n",
249 | " stratify=stratification\n",
250 | ") # we stratify by black and the target\n"
251 | ]
252 | },
253 | {
254 | "cell_type": "code",
255 | "execution_count": 25,
256 | "metadata": {
257 | "colab": {},
258 | "colab_type": "code",
259 | "id": "hWtkwLMvSiRr"
260 | },
261 | "outputs": [],
262 | "source": [
263 | "import jax.numpy as jnp\n",
264 | "from jax import grad, jit, vmap, ops, lax\n",
265 | "import numpy as onp\n",
266 | "import numpy.random as npr\n",
267 | "import random\n",
268 | "from tqdm import trange\n",
269 | "from sklearn.base import ClassifierMixin\n",
270 | "from sklearn.preprocessing import StandardScaler\n",
271 | "\n",
272 | "\n",
273 | "class JAXLearner(ClassifierMixin):\n",
274 | " def __init__(self, layer_sizes=[10, 5, 1], epochs=20, batch_size=500, lr=1e-2):\n",
275 | " self.params = self.construct_network(layer_sizes)\n",
276 | " self.perex_grads = jit(grad(self.error))\n",
277 | " self.epochs = epochs\n",
278 | " self.batch_size = batch_size\n",
279 | " self.lr = lr\n",
280 | "\n",
281 | " @staticmethod\n",
282 | " def construct_network(layer_sizes=[10, 5, 1]):\n",
283 | " '''Please make sure your final layer corresponds to targets in dimensions.\n",
284 | " '''\n",
285 | " def init_layer(n_in, n_out):\n",
286 | " W = npr.randn(n_in, n_out)\n",
287 | " b = npr.randn(n_out,)\n",
288 | " return W, b\n",
289 | " \n",
290 | " return list(map(init_layer, layer_sizes[:-1], layer_sizes[1:]))\n",
291 | "\n",
292 | " @staticmethod\n",
293 | " def sigmoid(X): # or tanh\n",
294 | " return 1/(1+jnp.exp(-X))\n",
295 | "\n",
296 | " def _predict(self, inputs, params=None):\n",
297 | " for W, b in (self.params if params is None else params):\n",
298 | " outputs = jnp.dot(inputs, W) + b\n",
299 | " inputs = self.sigmoid(outputs)\n",
300 | " return outputs\n",
301 | "\n",
302 | " def predict(self, inputs):\n",
303 | " inputs = self.standard_scaler.transform(inputs)\n",
304 | " return onp.asarray(self._predict(inputs))\n",
305 | "\n",
306 | " @staticmethod\n",
307 | " def mse(preds, targets, other=None):\n",
308 | " return jnp.mean((preds - targets)**2)\n",
309 | "\n",
310 | " @staticmethod\n",
311 | " def penalized_mse(preds, targets, sensitive):\n",
312 | " err = jnp.sum((preds - targets)**2)\n",
313 | " err_s = jnp.sum((preds * sensitive - targets * sensitive)**2)\n",
314 | " penalty = jnp.clip((err_s / jnp.maximum(jnp.sum(sensitive), 1)) / (err / preds.shape[0]), 1.0, 2.0) # group mean error vs. overall mean error\n",
315 | " return err * penalty\n",
316 | "\n",
317 | " def error(self, params, inputs, targets, sensitive):\n",
318 | " preds = self._predict(inputs, params)[:, 0] # differentiate w.r.t. params; flatten (B, 1) to (B,)\n",
319 | " return self.penalized_mse(preds, targets, sensitive)\n",
320 | "\n",
321 | " def fit(self, X, y, sensitive):\n",
322 | " self.standard_scaler = StandardScaler()\n",
323 | " X = self.standard_scaler.fit_transform(X)\n",
324 | " N = X.shape[0]\n",
325 | " indexes = list(range(N))\n",
326 | " steps_per_epoch = N // self.batch_size\n",
327 | "\n",
328 | " for epoch in trange(self.epochs, desc='training'):\n",
329 | " random.shuffle(indexes)\n",
330 | " index_offset = 0\n",
331 | " for step in trange(steps_per_epoch, desc='iteration'):\n",
332 | " grads = self.perex_grads(\n",
333 | " self.params, \n",
334 | " X[indexes[index_offset:index_offset+self.batch_size], :], \n",
335 | " y[indexes[index_offset:index_offset+self.batch_size]],\n",
336 | " sensitive[indexes[index_offset:index_offset+self.batch_size]]\n",
337 | " )\n",
338 | " # print(grads)\n",
339 | " self.params = [(W - self.lr * dW, b - self.lr * db)\n",
340 | " for (W, b), (dW, db) in zip(self.params, grads)]\n",
341 | " index_offset += self.batch_size\n"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": 26,
347 | "metadata": {
348 | "colab": {
349 | "base_uri": "https://localhost:8080/",
350 | "height": 714
351 | },
352 | "colab_type": "code",
353 | "id": "a45RcwvtSqhP",
354 | "outputId": "6e1625e7-e527-4ce0-d401-94722cdb5cfb"
355 | },
356 | "outputs": [
357 | {
358 | "name": "stderr",
359 | "output_type": "stream",
360 | "text": [
361 | "training: 0%| | 0/20 [00:00<?, ?it/s]\n",
362 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 460.32it/s]\n",
363 | "\n",
364 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 576.52it/s]\n",
365 | "\n",
366 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 627.77it/s]\n",
367 | "\n",
368 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 650.54it/s]\n",
369 | "\n",
370 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 737.82it/s]\n",
371 | "training: 25%|██▌ | 5/20 [00:00<00:00, 48.03it/s]\n",
372 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 638.95it/s]\n",
373 | "\n",
374 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 782.45it/s]\n",
375 | "\n",
376 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 601.05it/s]\n",
377 | "\n",
378 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 807.22it/s]\n",
379 | "\n",
380 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 628.20it/s]\n",
381 | "\n",
382 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 640.28it/s]\n",
383 | "training: 55%|█████▌ | 11/20 [00:00<00:00, 49.09it/s]\n",
384 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 691.06it/s]\n",
385 | "\n",
386 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 643.61it/s]\n",
387 | "\n",
388 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 808.37it/s]\n",
389 | "\n",
390 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 686.75it/s]\n",
391 | "\n",
392 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 810.38it/s]\n",
393 | "\n",
394 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 682.69it/s]\n",
395 | "training: 85%|████████▌ | 17/20 [00:00<00:00, 50.15it/s]\n",
396 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 682.21it/s]\n",
397 | "\n",
398 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 667.15it/s]\n",
399 | "\n",
400 | "iteration: 100%|██████████| 8/8 [00:00<00:00, 605.59it/s]\n",
401 | "training: 100%|██████████| 20/20 [00:00<00:00, 50.86it/s]\n"
402 | ]
403 | }
404 | ],
405 | "source": [
406 | "sensitive_train = X_train.join(\n",
407 | " data, rsuffix='_right'\n",
408 | ")['race_black']\n",
409 | "jax_learner = JAXLearner([X.values.shape[1], 100, 1])\n",
410 | "jax_learner.fit(\n",
411 | " X_train.values,\n",
412 | " y_train.values,\n",
413 | " sensitive_train.values\n",
414 | ")\n"
415 | ]
416 | },
417 | {
418 | "cell_type": "code",
419 | "execution_count": 28,
420 | "metadata": {
421 | "colab": {},
422 | "colab_type": "code",
423 | "id": "Kl5qF97cS9OL"
424 | },
425 | "outputs": [],
426 | "source": [
427 | "X_predicted = pd.DataFrame(\n",
428 | " data=jax_learner.predict(\n",
429 | " X_test.values\n",
430 | " ) * 10,\n",
431 | " columns=['score'], \n",
432 | " index=X_test.index\n",
433 | ").join(\n",
434 | " data[['sex', 'race', 'is_recid']], \n",
435 | " rsuffix='_right'\n",
436 | ")"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": 31,
442 | "metadata": {
443 | "colab": {
444 | "base_uri": "https://localhost:8080/",
445 | "height": 419
446 | },
447 | "colab_type": "code",
448 | "id": "KakKvALIVCLj",
449 | "outputId": "0734b142-471e-4750-9e36-3586a22b2ce4"
450 | },
451 | "outputs": [
452 | {
453 | "data": {
454 | "text/html": [
455 | "\n",
456 | "\n",
469 | "\n",
470 | " \n",
471 | " \n",
472 | " \n",
473 | " score \n",
474 | " sex \n",
475 | " race \n",
476 | " is_recid \n",
477 | " \n",
478 | " \n",
479 | " \n",
480 | " \n",
481 | " 6553 \n",
482 | " 45.214291 \n",
483 | " Male \n",
484 | " Caucasian \n",
485 | " 0 \n",
486 | " \n",
487 | " \n",
488 | " 1441 \n",
489 | " -42.396706 \n",
490 | " Male \n",
491 | " African-American \n",
492 | " 1 \n",
493 | " \n",
494 | " \n",
495 | " 2306 \n",
496 | " -60.853111 \n",
497 | " Male \n",
498 | " Caucasian \n",
499 | " 0 \n",
500 | " \n",
501 | " \n",
502 | " 504 \n",
503 | " -12.313410 \n",
504 | " Male \n",
505 | " Caucasian \n",
506 | " 0 \n",
507 | " \n",
508 | " \n",
509 | " 5212 \n",
510 | " 27.922726 \n",
511 | " Male \n",
512 | " African-American \n",
513 | " 0 \n",
514 | " \n",
515 | " \n",
516 | " ... \n",
517 | " ... \n",
518 | " ... \n",
519 | " ... \n",
520 | " ... \n",
521 | " \n",
522 | " \n",
523 | " 6118 \n",
524 | " 46.818214 \n",
525 | " Male \n",
526 | " Caucasian \n",
527 | " 1 \n",
528 | " \n",
529 | " \n",
530 | " 607 \n",
531 | " 9.088397 \n",
532 | " Female \n",
533 | " Caucasian \n",
534 | " 0 \n",
535 | " \n",
536 | " \n",
537 | " 2596 \n",
538 | " -12.438754 \n",
539 | " Female \n",
540 | " Caucasian \n",
541 | " 1 \n",
542 | " \n",
543 | " \n",
544 | " 3204 \n",
545 | " -87.487778 \n",
546 | " Male \n",
547 | " African-American \n",
548 | " 1 \n",
549 | " \n",
550 | " \n",
551 | " 4692 \n",
552 | " 36.880836 \n",
553 | " Female \n",
554 | " African-American \n",
555 | " 0 \n",
556 | " \n",
557 | " \n",
558 | "\n",
559 | "2052 rows × 4 columns\n",
560 | ""
561 | ],
562 | "text/plain": [
563 | " score sex race is_recid\n",
564 | "6553 45.214291 Male Caucasian 0\n",
565 | "1441 -42.396706 Male African-American 1\n",
566 | "2306 -60.853111 Male Caucasian 0\n",
567 | "504 -12.313410 Male Caucasian 0\n",
568 | "5212 27.922726 Male African-American 0\n",
569 | "... ... ... ... ...\n",
570 | "6118 46.818214 Male Caucasian 1\n",
571 | "607 9.088397 Female Caucasian 0\n",
572 | "2596 -12.438754 Female Caucasian 1\n",
573 | "3204 -87.487778 Male African-American 1\n",
574 | "4692 36.880836 Female African-American 0\n",
575 | "\n",
576 | "[2052 rows x 4 columns]"
577 | ]
578 | },
579 | "execution_count": 31,
580 | "metadata": {
581 | "tags": []
582 | },
583 | "output_type": "execute_result"
584 | }
585 | ],
586 | "source": [
587 | "X_predicted"
588 | ]
589 | },
590 | {
591 | "cell_type": "code",
592 | "execution_count": 32,
593 | "metadata": {
594 | "colab": {
595 | "base_uri": "https://localhost:8080/",
596 | "height": 266
597 | },
598 | "colab_type": "code",
599 | "id": "D_T4dZAYS-sP",
600 | "outputId": "b5399c87-dfef-4aa2-a7b8-c1780749eeee"
601 | },
602 | "outputs": [
603 | {
604 | "data": {
605 | "text/html": [
606 | "\n",
607 | "\n",
620 | "\n",
621 | " \n",
622 | " \n",
623 | " \n",
624 | " reoffend \n",
625 | " score \n",
626 | " N \n",
627 | " FPR \n",
628 | " FNR \n",
629 | " DFP \n",
630 | " DFN \n",
631 | " \n",
632 | " \n",
633 | " sensitive \n",
634 | " \n",
635 | " \n",
636 | " \n",
637 | " \n",
638 | " \n",
639 | " \n",
640 | " \n",
641 | " \n",
642 | " \n",
643 | " \n",
644 | " \n",
645 | " African-American \n",
646 | " 0.471042 \n",
647 | " -0.948788 \n",
648 | " 1036 \n",
649 | " 0.385036 \n",
650 | " 0.487842 \n",
651 | " 0.868148 \n",
652 | " 1.401687 \n",
653 | " \n",
654 | " \n",
655 | " Asian \n",
656 | " 0.222222 \n",
657 | " 1.802761 \n",
658 | " 9 \n",
659 | " 0.571429 \n",
660 | " 0.250000 \n",
661 | " 1.288410 \n",
662 | " 0.718310 \n",
663 | " \n",
664 | " \n",
665 | " Caucasian \n",
666 | " 0.335188 \n",
667 | " -0.308411 \n",
668 | " 719 \n",
669 | " 0.443515 \n",
670 | " 0.348039 \n",
671 | " 1.000000 \n",
672 | " 1.000000 \n",
673 | " \n",
674 | " \n",
675 | " Hispanic \n",
676 | " 0.293478 \n",
677 | " 0.121934 \n",
678 | " 184 \n",
679 | " 0.500000 \n",
680 | " 0.329897 \n",
681 | " 1.127358 \n",
682 | " 0.947873 \n",
683 | " \n",
684 | " \n",
685 | " Native American \n",
686 | " 0.500000 \n",
687 | " 0.963587 \n",
688 | " 2 \n",
689 | " 0.000000 \n",
690 | " 0.000000 \n",
691 | " 0.000000 \n",
692 | " 0.000000 \n",
693 | " \n",
694 | " \n",
695 | " Other \n",
696 | " 0.294118 \n",
697 | " -0.626596 \n",
698 | " 102 \n",
699 | " 0.430556 \n",
700 | " 0.267857 \n",
701 | " 0.970781 \n",
702 | " 0.769618 \n",
703 | " \n",
704 | " \n",
705 | "\n",
706 | ""
707 | ],
708 | "text/plain": [
709 | " reoffend score N ... FNR DFP DFN\n",
710 | "sensitive ... \n",
711 | "African-American 0.471042 -0.948788 1036 ... 0.487842 0.868148 1.401687\n",
712 | "Asian 0.222222 1.802761 9 ... 0.250000 1.288410 0.718310\n",
713 | "Caucasian 0.335188 -0.308411 719 ... 0.348039 1.000000 1.000000\n",
714 | "Hispanic 0.293478 0.121934 184 ... 0.329897 1.127358 0.947873\n",
715 | "Native American 0.500000 0.963587 2 ... 0.000000 0.000000 0.000000\n",
716 | "Other 0.294118 -0.626596 102 ... 0.267857 0.970781 0.769618\n",
717 | "\n",
718 | "[6 rows x 7 columns]"
719 | ]
720 | },
721 | "execution_count": 32,
722 | "metadata": {
723 | "tags": []
724 | },
725 | "output_type": "execute_result"
726 | }
727 | ],
728 | "source": [
729 | "calculate_impacts(X_predicted, score_col='score')"
730 | ]
731 | },
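{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "In the table above, `DFP` and `DFN` appear to be each group's false positive rate (`FPR`) and false negative rate (`FNR`) divided by the Caucasian group's rates. For African-American defendants, 0.385036 / 0.443515 ≈ 0.868 and 0.487842 / 0.348039 ≈ 1.402. The sketch below recomputes rates of this kind with pandas. It is not the notebook's `calculate_impacts`. The helper name `group_error_rates`, the rule that a score above `threshold` counts as a predicted reoffence, and the default threshold of 0 are all assumptions.\n",
  "\n",
  "```python\n",
  "import pandas as pd\n",
  "\n",
  "def group_error_rates(df, score_col='score', label_col='is_recid',\n",
  "                      group_col='race', threshold=0.0, reference='Caucasian'):\n",
  "    # Assumption: a case is predicted to reoffend when its score exceeds the threshold.\n",
  "    pred = df[score_col] > threshold\n",
  "    actual = df[label_col] == 1\n",
  "    rows = {}\n",
  "    for group, idx in df.groupby(group_col).groups.items():\n",
  "        p, a = pred.loc[idx], actual.loc[idx]\n",
  "        negatives, positives = (~a).sum(), a.sum()\n",
  "        # False positive rate: share of non-reoffenders flagged as reoffending.\n",
  "        fpr = (p & ~a).sum() / negatives if negatives else 0.0\n",
  "        # False negative rate: share of reoffenders who were not flagged.\n",
  "        fnr = (~p & a).sum() / positives if positives else 0.0\n",
  "        rows[group] = {'N': len(idx), 'FPR': fpr, 'FNR': fnr}\n",
  "    out = pd.DataFrame.from_dict(rows, orient='index')\n",
  "    # Disparities as ratios to the reference group's error rates.\n",
  "    out['DFP'] = out['FPR'] / out.loc[reference, 'FPR']\n",
  "    out['DFN'] = out['FNR'] / out.loc[reference, 'FNR']\n",
  "    return out\n",
  "```\n",
  "\n",
  "A ratio above 1 means a group suffers that type of error more often than the reference group. Groups with very few cases, such as the two Native American defendants in this test set, give unstable rates.\n"
 ]
},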
732 | {
733 | "cell_type": "code",
734 | "execution_count": null,
735 | "metadata": {
736 | "colab": {},
737 | "colab_type": "code",
738 | "id": "IPWiMPpyTUI7"
739 | },
740 | "outputs": [],
741 | "source": []
742 | }
743 | ],
744 | "metadata": {
745 | "colab": {
746 | "collapsed_sections": [],
747 | "name": "recidivism + active learning",
748 | "provenance": []
749 | },
750 | "kernelspec": {
751 | "display_name": "Python 3",
752 | "language": "python",
753 | "name": "python3"
754 | },
755 | "language_info": {
756 | "codemirror_mode": {
757 | "name": "ipython",
758 | "version": 3
759 | },
760 | "file_extension": ".py",
761 | "mimetype": "text/x-python",
762 | "name": "python",
763 | "nbconvert_exporter": "python",
764 | "pygments_lexer": "ipython3",
765 | "version": "3.8.3"
766 | }
767 | },
768 | "nbformat": 4,
769 | "nbformat_minor": 1
770 | }
771 |
--------------------------------------------------------------------------------
/chapter03/Recommending_products.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "Recommending_products",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3",
13 | "name": "python3"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {
20 | "id": "NmAgD1Ta057u"
21 | },
22 | "source": [
23 | "Dataset\n",
24 | "----------\n",
25 | "http://cseweb.ucsd.edu/~jmcauley/datasets.html#goodreads\n",
26 | "* Items:\t1,561,465\n",
27 | "* Users:\t808,749\n",
28 | "* Interactions:\t225,394,930\n",
29 | "\n",
30 | "```json\n",
31 | "{\n",
32 | " \"user_id\": \"8842281e1d1347389f2ab93d60773d4d\",\n",
33 | " \"book_id\": \"130580\",\n",
34 | " \"review_id\": \"330f9c153c8d3347eb914c06b89c94da\",\n",
35 | " \"isRead\": true,\n",
36 | " \"rating\": 4,\n",
37 | " \"date_added\": \"Mon Aug 01 13:41:57 -0700 2011\",\n",
38 | " \"date_updated\": \"Mon Aug 01 13:42:41 -0700 2011\",\n",
39 | " \"read_at\": \"Fri Jan 01 00:00:00 -0800 1988\",\n",
40 | " \"started_at\": \"\"\n",
41 | "}\n",
42 | "```\n",
43 | "\n",
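44 | "See also the Amazon product co-purchasing metadata at https://snap.stanford.edu/data/amazon-meta.html\n",
45 | "\n",
46 | "The statistics above describe the full UCSD Goodreads dataset. The code below trains on the much smaller goodbooks-10k ratings (5,976,479 interactions with 10,000 books) that Spotlight's `_get_dataset()` downloads.\n"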
47 | ]
48 | },
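{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Records in the format shown above can be parsed with the standard library. This sketch makes three assumptions: there is one JSON object per line, timestamps follow the layout of the example record, and day and month names are English (C locale). The helper name `parse_interaction` is hypothetical. Empty or missing dates become `None`.\n",
  "\n",
  "```python\n",
  "import json\n",
  "from datetime import datetime\n",
  "\n",
  "# Timestamp layout of the example record, e.g. 'Mon Aug 01 13:41:57 -0700 2011'.\n",
  "GOODREADS_TIME_FORMAT = '%a %b %d %H:%M:%S %z %Y'\n",
  "\n",
  "def parse_interaction(line):\n",
  "    # Decode one JSON-encoded interaction and convert its date fields to datetimes.\n",
  "    record = json.loads(line)\n",
  "    for key in ('date_added', 'date_updated', 'read_at', 'started_at'):\n",
  "        value = record.get(key)\n",
  "        # An empty string (as for started_at above) or a missing key means no date.\n",
  "        record[key] = datetime.strptime(value, GOODREADS_TIME_FORMAT) if value else None\n",
  "    return record\n",
  "```\n",
  "\n",
  "For the record above, `parse_interaction(line)['date_added'].year` is 2011.\n"
 ]
},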
49 | {
50 | "cell_type": "code",
51 | "metadata": {
52 | "id": "ch-tCyn66gPI",
53 | "outputId": "197ba60c-cee2-4dee-bd6b-592a919596f1",
54 | "colab": {
55 | "base_uri": "https://localhost:8080/",
56 | "height": 258
57 | }
58 | },
59 | "source": [
60 | "pip install git+https://github.com/maciejkula/spotlight.git"
61 | ],
62 | "execution_count": null,
63 | "outputs": [
64 | {
65 | "output_type": "stream",
66 | "text": [
67 | "Collecting git+https://github.com/maciejkula/spotlight.git\n",
68 | " Cloning https://github.com/maciejkula/spotlight.git to /tmp/pip-req-build-myzn2a4l\n",
69 | " Running command git clone -q https://github.com/maciejkula/spotlight.git /tmp/pip-req-build-myzn2a4l\n",
70 | "Requirement already satisfied: torch>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spotlight==0.1.6) (1.6.0+cu101)\n",
71 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch>=0.4.0->spotlight==0.1.6) (1.18.5)\n",
72 | "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch>=0.4.0->spotlight==0.1.6) (0.16.0)\n",
73 | "Building wheels for collected packages: spotlight\n",
74 | " Building wheel for spotlight (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
75 | " Created wheel for spotlight: filename=spotlight-0.1.6-cp36-none-any.whl size=33919 sha256=320c36fcaffb21d494d593397c529c6c75804929542b337f01d5607e3c1598f7\n",
76 | " Stored in directory: /tmp/pip-ephem-wheel-cache-fdvnpz7v/wheels/0a/33/c8/e8510ea648aaacf6031e128dfa92bcd3750f02db2aaf0922fe\n",
77 | "Successfully built spotlight\n",
78 | "Installing collected packages: spotlight\n",
79 | "Successfully installed spotlight-0.1.6\n"
80 | ],
81 | "name": "stdout"
82 | }
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "metadata": {
88 | "id": "fyJ0Beol6jqd"
89 | },
90 | "source": [
91 | "from spotlight.datasets.goodbooks import get_goodbooks_dataset, _get_dataset\n",
92 | "from spotlight.interactions import Interactions\n"
93 | ],
94 | "execution_count": null,
95 | "outputs": []
96 | },
97 | {
98 | "cell_type": "code",
99 | "metadata": {
100 | "id": "o11d-z9xW23o",
101 | "outputId": "70de19f0-96e1-4929-ee99-bc5ff0ec103b",
102 | "colab": {
103 | "base_uri": "https://localhost:8080/",
104 | "height": 224
105 | }
106 | },
107 | "source": [
108 | "!wget https://raw.githubusercontent.com/zygmuntz/goodbooks-10k/master/books.csv"
109 | ],
110 | "execution_count": null,
111 | "outputs": [
112 | {
113 | "output_type": "stream",
114 | "text": [
115 | "--2020-10-18 20:26:25-- https://raw.githubusercontent.com/zygmuntz/goodbooks-10k/master/books.csv\n",
116 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n",
117 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n",
118 | "HTTP request sent, awaiting response... 200 OK\n",
119 | "Length: 3286659 (3.1M) [text/plain]\n",
120 | "Saving to: ‘books.csv’\n",
121 | "\n",
122 | "books.csv 100%[===================>] 3.13M --.-KB/s in 0.1s \n",
123 | "\n",
124 | "2020-10-18 20:26:25 (25.4 MB/s) - ‘books.csv’ saved [3286659/3286659]\n",
125 | "\n"
126 | ],
127 | "name": "stdout"
128 | }
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "metadata": {
134 | "id": "DawSG98zWFOQ"
135 | },
136 | "source": [
137 | "import pandas as pd\n",
138 | "books = pd.read_csv('books.csv', index_col=0)"
139 | ],
140 | "execution_count": null,
141 | "outputs": []
142 | },
143 | {
144 | "cell_type": "code",
145 | "metadata": {
146 | "id": "o8HTfPwRW8F_",
147 | "outputId": "cb39a8a2-f237-411c-9c9d-c9e291fdf0dd",
148 | "colab": {
149 | "base_uri": "https://localhost:8080/",
150 | "height": 459
151 | }
152 | },
153 | "source": [
154 | "books.head()"
155 | ],
156 | "execution_count": null,
157 | "outputs": [
158 | {
159 | "output_type": "execute_result",
160 | "data": {
161 | "text/html": [
162 | "\n",
163 | "\n",
176 | "\n",
177 | " \n",
178 | " \n",
179 | " \n",
180 | " goodreads_book_id \n",
181 | " best_book_id \n",
182 | " work_id \n",
183 | " books_count \n",
184 | " isbn \n",
185 | " isbn13 \n",
186 | " authors \n",
187 | " original_publication_year \n",
188 | " original_title \n",
189 | " title \n",
190 | " language_code \n",
191 | " average_rating \n",
192 | " ratings_count \n",
193 | " work_ratings_count \n",
194 | " work_text_reviews_count \n",
195 | " ratings_1 \n",
196 | " ratings_2 \n",
197 | " ratings_3 \n",
198 | " ratings_4 \n",
199 | " ratings_5 \n",
200 | " image_url \n",
201 | " small_image_url \n",
202 | " \n",
203 | " \n",
204 | " book_id \n",
205 | " \n",
206 | " \n",
207 | " \n",
208 | " \n",
209 | " \n",
210 | " \n",
211 | " \n",
212 | " \n",
213 | " \n",
214 | " \n",
215 | " \n",
216 | " \n",
217 | " \n",
218 | " \n",
219 | " \n",
220 | " \n",
221 | " \n",
222 | " \n",
223 | " \n",
224 | " \n",
225 | " \n",
226 | " \n",
227 | " \n",
228 | " \n",
229 | " \n",
230 | " \n",
231 | " 1 \n",
232 | " 2767052 \n",
233 | " 2767052 \n",
234 | " 2792775 \n",
235 | " 272 \n",
236 | " 439023483 \n",
237 | " 9.780439e+12 \n",
238 | " Suzanne Collins \n",
239 | " 2008.0 \n",
240 | " The Hunger Games \n",
241 | " The Hunger Games (The Hunger Games, #1) \n",
242 | " eng \n",
243 | " 4.34 \n",
244 | " 4780653 \n",
245 | " 4942365 \n",
246 | " 155254 \n",
247 | " 66715 \n",
248 | " 127936 \n",
249 | " 560092 \n",
250 | " 1481305 \n",
251 | " 2706317 \n",
252 | " https://images.gr-assets.com/books/1447303603m... \n",
253 | " https://images.gr-assets.com/books/1447303603s... \n",
254 | " \n",
255 | " \n",
256 | " 2 \n",
257 | " 3 \n",
258 | " 3 \n",
259 | " 4640799 \n",
260 | " 491 \n",
261 | " 439554934 \n",
262 | " 9.780440e+12 \n",
263 | " J.K. Rowling, Mary GrandPré \n",
264 | " 1997.0 \n",
265 | " Harry Potter and the Philosopher's Stone \n",
266 | " Harry Potter and the Sorcerer's Stone (Harry P... \n",
267 | " eng \n",
268 | " 4.44 \n",
269 | " 4602479 \n",
270 | " 4800065 \n",
271 | " 75867 \n",
272 | " 75504 \n",
273 | " 101676 \n",
274 | " 455024 \n",
275 | " 1156318 \n",
276 | " 3011543 \n",
277 | " https://images.gr-assets.com/books/1474154022m... \n",
278 | " https://images.gr-assets.com/books/1474154022s... \n",
279 | " \n",
280 | " \n",
281 | " 3 \n",
282 | " 41865 \n",
283 | " 41865 \n",
284 | " 3212258 \n",
285 | " 226 \n",
286 | " 316015849 \n",
287 | " 9.780316e+12 \n",
288 | " Stephenie Meyer \n",
289 | " 2005.0 \n",
290 | " Twilight \n",
291 | " Twilight (Twilight, #1) \n",
292 | " en-US \n",
293 | " 3.57 \n",
294 | " 3866839 \n",
295 | " 3916824 \n",
296 | " 95009 \n",
297 | " 456191 \n",
298 | " 436802 \n",
299 | " 793319 \n",
300 | " 875073 \n",
301 | " 1355439 \n",
302 | " https://images.gr-assets.com/books/1361039443m... \n",
303 | " https://images.gr-assets.com/books/1361039443s... \n",
304 | " \n",
305 | " \n",
306 | " 4 \n",
307 | " 2657 \n",
308 | " 2657 \n",
309 | " 3275794 \n",
310 | " 487 \n",
311 | " 61120081 \n",
312 | " 9.780061e+12 \n",
313 | " Harper Lee \n",
314 | " 1960.0 \n",
315 | " To Kill a Mockingbird \n",
316 | " To Kill a Mockingbird \n",
317 | " eng \n",
318 | " 4.25 \n",
319 | " 3198671 \n",
320 | " 3340896 \n",
321 | " 72586 \n",
322 | " 60427 \n",
323 | " 117415 \n",
324 | " 446835 \n",
325 | " 1001952 \n",
326 | " 1714267 \n",
327 | " https://images.gr-assets.com/books/1361975680m... \n",
328 | " https://images.gr-assets.com/books/1361975680s... \n",
329 | " \n",
330 | " \n",
331 | " 5 \n",
332 | " 4671 \n",
333 | " 4671 \n",
334 | " 245494 \n",
335 | " 1356 \n",
336 | " 743273567 \n",
337 | " 9.780743e+12 \n",
338 | " F. Scott Fitzgerald \n",
339 | " 1925.0 \n",
340 | " The Great Gatsby \n",
341 | " The Great Gatsby \n",
342 | " eng \n",
343 | " 3.89 \n",
344 | " 2683664 \n",
345 | " 2773745 \n",
346 | " 51992 \n",
347 | " 86236 \n",
348 | " 197621 \n",
349 | " 606158 \n",
350 | " 936012 \n",
351 | " 947718 \n",
352 | " https://images.gr-assets.com/books/1490528560m... \n",
353 | " https://images.gr-assets.com/books/1490528560s... \n",
354 | " \n",
355 | " \n",
356 | "\n",
357 | ""
358 | ],
359 | "text/plain": [
360 | " goodreads_book_id ... small_image_url\n",
361 | "book_id ... \n",
362 | "1 2767052 ... https://images.gr-assets.com/books/1447303603s...\n",
363 | "2 3 ... https://images.gr-assets.com/books/1474154022s...\n",
364 | "3 41865 ... https://images.gr-assets.com/books/1361039443s...\n",
365 | "4 2657 ... https://images.gr-assets.com/books/1361975680s...\n",
366 | "5 4671 ... https://images.gr-assets.com/books/1490528560s...\n",
367 | "\n",
368 | "[5 rows x 22 columns]"
369 | ]
370 | },
371 | "metadata": {
372 | "tags": []
373 | },
374 | "execution_count": 5
375 | }
376 | ]
377 | },
378 | {
379 | "cell_type": "code",
380 | "metadata": {
381 | "id": "IUeJXtMpXKOj"
382 | },
383 | "source": [
384 | "def get_book_titles(book_ids):\n",
385 | " '''Get book titles by book ids\n",
386 | " Example:\n",
387 | " --------\n",
388 | " >> get_book_titles(1)\n",
389 | " ['The Hunger Games (The Hunger Games, #1)']\n",
390 | " '''\n",
391 | " if isinstance(book_ids, int):\n",
392 | " book_ids = [book_ids]\n",
393 | " titles = []\n",
394 | " for book_id in book_ids:\n",
395 | " titles.append(books.loc[book_id, 'title'])\n",
396 | " return titles"
397 | ],
398 | "execution_count": null,
399 | "outputs": []
400 | },
401 | {
402 | "cell_type": "code",
403 | "metadata": {
404 | "id": "pZlOY_lC4Mkg"
405 | },
406 | "source": [
407 | "data = _get_dataset()\n",
408 | "interactions = Interactions(*data)"
409 | ],
410 | "execution_count": null,
411 | "outputs": []
412 | },
413 | {
414 | "cell_type": "code",
415 | "metadata": {
416 | "id": "AMi2kZX_VRkk",
417 | "outputId": "fb5fedc5-27b8-438c-e02f-7d814012470c",
418 | "colab": {
419 | "base_uri": "https://localhost:8080/",
420 | "height": 102
421 | }
422 | },
423 | "source": [
424 | "data"
425 | ],
426 | "execution_count": null,
427 | "outputs": [
428 | {
429 | "output_type": "execute_result",
430 | "data": {
431 | "text/plain": [
432 | "(array([ 1, 2, 2, ..., 49925, 49925, 49925], dtype=int32),\n",
433 | " array([ 258, 4081, 260, ..., 722, 949, 1023], dtype=int32),\n",
434 | " array([5., 4., 5., ..., 4., 5., 4.], dtype=float32),\n",
435 | " array([ 0, 1, 2, ..., 5976476, 5976477, 5976478],\n",
436 | " dtype=int32))"
437 | ]
438 | },
439 | "metadata": {
440 | "tags": []
441 | },
442 | "execution_count": 8
443 | }
444 | ]
445 | },
446 | {
447 | "cell_type": "code",
448 | "metadata": {
449 | "id": "BKtSQq047pju",
450 | "outputId": "9fdfda85-066f-4b62-c54a-82b45aac6b4e",
451 | "colab": {
452 | "base_uri": "https://localhost:8080/",
453 | "height": 34
454 | }
455 | },
456 | "source": [
457 | "print(interactions)"
458 | ],
459 | "execution_count": null,
460 | "outputs": [
461 | {
462 | "output_type": "stream",
463 | "text": [
464 | "\n"
465 | ],
466 | "name": "stdout"
467 | }
468 | ]
469 | },
470 | {
471 | "cell_type": "code",
472 | "metadata": {
473 | "id": "By5n99R57yrp"
474 | },
475 | "source": [
476 | "import torch\n",
477 | "\n",
478 | "from spotlight.factorization.explicit import ExplicitFactorizationModel\n",
479 | "\n",
480 | "model = ExplicitFactorizationModel(loss='regression',\n",
481 | " embedding_dim=128, # latent dimensionality\n",
482 | " n_iter=10, # number of epochs of training\n",
483 | " batch_size=1024, # minibatch size\n",
484 | " l2=1e-9, # strength of L2 regularization\n",
485 | " learning_rate=1e-3,\n",
486 | " use_cuda=torch.cuda.is_available())\n"
487 | ],
488 | "execution_count": 30,
489 | "outputs": []
490 | },
491 | {
492 | "cell_type": "code",
493 | "metadata": {
494 | "id": "fO9GM0Zx7z_x"
495 | },
496 | "source": [
497 | "from spotlight.cross_validation import random_train_test_split\n",
498 | "import numpy as np\n",
499 | "\n",
500 | "train, test = random_train_test_split(interactions, random_state=np.random.RandomState(42))\n"
501 | ],
502 | "execution_count": 31,
503 | "outputs": []
504 | },
505 | {
506 | "cell_type": "code",
507 | "metadata": {
508 | "id": "Nk7jhvaZ8FsD",
509 | "outputId": "0c561ba5-6557-4a3e-960c-628f335fadda",
510 | "colab": {
511 | "base_uri": "https://localhost:8080/",
512 | "height": 68
513 | }
514 | },
515 | "source": [
516 | "print('Split into \\n {} and \\n {}.'.format(train, test))"
517 | ],
518 | "execution_count": 32,
519 | "outputs": [
520 | {
521 | "output_type": "stream",
522 | "text": [
523 | "Split into \n",
524 | " and \n",
525 | " .\n"
526 | ],
527 | "name": "stdout"
528 | }
529 | ]
530 | },
531 | {
532 | "cell_type": "code",
533 | "metadata": {
534 | "id": "ekZwgODA8Jz0",
535 | "outputId": "e39cfd35-d737-4635-c84a-93194f67c5ec",
536 | "colab": {
537 | "base_uri": "https://localhost:8080/",
538 | "height": 238
539 | }
540 | },
541 | "source": [
542 | "model.fit(train, verbose=True)\n",
543 | "from spotlight.evaluation import rmse_score, precision_recall_score\n",
544 | "\n",
545 | "train_rmse = rmse_score(model, train)\n",
546 | "test_rmse = rmse_score(model, test)\n",
547 | "train_precision, train_recall = precision_recall_score(model, train, k=5)\n",
548 | "test_precision, test_recall = precision_recall_score(model, test, k=5)\n",
549 | "\n",
550 | "print('Train RMSE {:.3f}, test RMSE {:.3f}'.format(train_rmse, test_rmse))\n",
551 | "print(\n",
552 | " 'mean train precision at 5: {:.3f}'.format(\n",
553 | " train_precision.mean()\n",
554 | "))\n",
555 | "print(\n",
556 | " 'mean test precision at 5: {:.3f}'.format(\n",
557 | " test_precision.mean()\n",
558 | "))"
559 | ],
560 | "execution_count": 33,
561 | "outputs": [
562 | {
563 | "output_type": "stream",
564 | "text": [
565 | "Epoch 0: loss 2.762984432512994\n",
566 | "Epoch 1: loss 0.7377435095815638\n",
567 | "Epoch 2: loss 0.6588062686379001\n",
568 | "Epoch 3: loss 0.5621107913633485\n",
569 | "Epoch 4: loss 0.44322063551469837\n",
570 | "Epoch 5: loss 0.32800119822681334\n",
571 | "Epoch 6: loss 0.2381617845189648\n",
572 | "Epoch 7: loss 0.1754616707276226\n",
573 | "Epoch 8: loss 0.1336526694912982\n",
574 | "Epoch 9: loss 0.10600318515422003\n",
575 | "Train RMSE 0.265, test RMSE 0.965\n",
576 | "mean train precision at 5: 0.027\n",
577 | "mean test precision at 5: 0.020\n"
578 | ],
579 | "name": "stdout"
580 | }
581 | ]
582 | },
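{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "For reference, the two metrics printed above can be written out in a few lines of NumPy. This is a simplified sketch for a single user, not Spotlight's implementation. `precision_recall_at_k` and `rmse` are hypothetical helpers, and Spotlight's `precision_recall_score` may handle details such as already-seen items differently.\n",
  "\n",
  "```python\n",
  "import numpy as np\n",
  "\n",
  "def precision_recall_at_k(scores, relevant, k=5):\n",
  "    # scores: predicted score for every item for one user (higher is better).\n",
  "    # relevant: indices of the items this user rated in the held-out set.\n",
  "    top_k = np.argsort(-scores)[:k]\n",
  "    hits = len(set(top_k.tolist()) & set(relevant))\n",
  "    precision = hits / k\n",
  "    recall = hits / len(relevant) if len(relevant) else 0.0\n",
  "    return precision, recall\n",
  "\n",
  "def rmse(predicted, actual):\n",
  "    # Root mean squared error between predicted and observed ratings.\n",
  "    predicted, actual = np.asarray(predicted, float), np.asarray(actual, float)\n",
  "    return float(np.sqrt(np.mean((predicted - actual) ** 2)))\n",
  "```\n",
  "\n",
  "Precision at 5 is the share of a user's five highest-scored books that the user actually rated in the held-out set. A mean of about 0.02 therefore corresponds to roughly 0.1 hits per user's top-5 list. The gap between train RMSE (0.265) and test RMSE (0.965) suggests the model overfits the training ratings.\n"
 ]
},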
583 | {
584 | "cell_type": "code",
585 | "metadata": {
586 | "id": "Fs9I0Vk0oHiT"
587 | },
588 | "source": [
589 | "# Show known positives next to the top recommendations. Based on https://github.com/lyst/lightfm/blob/master/examples/quickstart/quickstart.ipynb\n",
590 | "\n",
591 | "def sample_recommendation(model, user_ids, train, item_labels):\n",
592 | " '''Print up to three known positives and the top three recommendations for each user.\n",
593 | " '''\n",
594 | " n_users, n_items = train.shape\n",
595 | "\n",
596 | " for user_id in user_ids:\n",
597 | " known_positives = item_labels[train[user_id].indices]\n",
598 | " \n",
599 | " scores = model.predict(user_id, np.arange(n_items))\n",
600 | " top_items = item_labels[np.argsort(-scores)]\n",
601 | " \n",
602 | " print(\"User %s\" % user_id)\n",
603 | " print(\" Known positives:\")\n",
604 | " \n",
605 | " for x in known_positives[:3]:\n",
606 | " print(\" %s\" % x)\n",
607 | "\n",
608 | " print(\" Recommended:\")\n",
609 | " \n",
610 | " for x in top_items[:3]:\n",
611 | " print(\" %s\" % x)"
612 | ],
613 | "execution_count": 34,
614 | "outputs": []
615 | },
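{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "`sample_recommendation` ranks every book by `model.predict`. For a factorization model such as Spotlight's `ExplicitFactorizationModel`, that score is assumed here to be a dot product between a user embedding and each item embedding, plus bias terms. The NumPy sketch below assumes that bilinear form with made-up arrays (`recommend_top_k` is a hypothetical helper). It also masks books the user already rated, which `sample_recommendation` does not do. A per-user bias is the same for every item and does not change the ranking, so the sketch leaves it out.\n",
  "\n",
  "```python\n",
  "import numpy as np\n",
  "\n",
  "def recommend_top_k(user_vec, item_vecs, item_bias, known_items, k=3):\n",
  "    # Bilinear score: dot product of user and item factors plus an item bias.\n",
  "    scores = item_vecs @ user_vec + item_bias\n",
  "    # Exclude items the user has already rated so they are not re-recommended.\n",
  "    scores[list(known_items)] = -np.inf\n",
  "    # Highest-scoring remaining items first.\n",
  "    return np.argsort(-scores)[:k]\n",
  "```\n"
 ]
},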
616 | {
617 | "cell_type": "code",
618 | "metadata": {
619 | "id": "TCHsBHOLoTgD",
620 | "outputId": "adb055f1-1e60-43c0-ed58-1b010c58779f",
621 | "colab": {
622 | "base_uri": "https://localhost:8080/",
623 | "height": 187
624 | }
625 | },
626 | "source": [
627 | "book_labels = ['(no book)'] + get_book_titles(list(range(1, train.num_items)))  # label i = title of book_id i\n",
628 | "book_labels[:10]"
629 | ],
630 | "execution_count": null,
631 | "outputs": []
654 | },
655 | {
656 | "cell_type": "code",
657 | "metadata": {
658 | "id": "OIi2bGhEoc5l",
659 | "outputId": "e1a1af78-76d4-4d00-f64a-770471af9a6c",
660 | "colab": {
661 | "base_uri": "https://localhost:8080/",
662 | "height": 476
663 | }
664 | },
665 | "source": [
666 | "sample_recommendation(model, [3, 9999, 15000], train.tocsr(), np.array(book_labels))"
667 | ],
668 | "execution_count": null,
669 | "outputs": []
704 | },
705 | {
706 | "cell_type": "code",
707 | "metadata": {
708 | "id": "muIudpFLK_Q4",
709 | "outputId": "05edc387-8c95-4f3b-cdc5-6a774f860e5b",
710 | "colab": {
711 | "base_uri": "https://localhost:8080/",
712 | "height": 34
713 | }
714 | },
715 | "source": [
716 | "from spotlight.evaluation import precision_recall_score\n",
717 | "\n",
718 | "train_prs = precision_recall_score(model, train, k=5)\n",
719 | "test_prs = precision_recall_score(model, test, k=5)\n",
720 | "\n",
721 | "print('Train precision@5 {:.3f}, test precision@5 {:.3f}'.format(train_prs[0].mean(), test_prs[0].mean()))"
722 | ],
723 | "execution_count": 37,
724 | "outputs": [
725 | {
726 | "output_type": "stream",
727 | "text": [
728 | "Train precision@5 0.027, test precision@5 0.020\n"
729 | ],
730 | "name": "stdout"
731 | }
732 | ]
733 | },
734 | {
735 | "cell_type": "code",
736 | "metadata": {
737 | "id": "0HmxqtCuug8a",
738 | "outputId": "01a326ed-e1c0-4593-b645-f63338f1b004",
739 | "colab": {
740 | "base_uri": "https://localhost:8080/",
741 | "height": 51
742 | }
743 | },
744 | "source": [
745 | "print(\n",
746 | " 'mean train precision at 5: {:.3f}'.format(\n",
747 | " train_prs[0].mean()\n",
748 | "))\n",
749 | "print(\n",
750 | " 'mean test precision at 5: {:.3f}'.format(\n",
751 | " test_prs[0].mean()\n",
752 | "))"
753 | ],
754 | "execution_count": 38,
755 | "outputs": [
756 | {
757 | "output_type": "stream",
758 | "text": [
759 | "mean train precision at 5: 0.027\n",
760 | "mean test precision at 5: 0.020\n"
761 | ],
762 | "name": "stdout"
763 | }
764 | ]
765 | },
766 | {
767 | "cell_type": "code",
768 | "metadata": {
769 | "id": "UBpy156e9vGq",
770 | "outputId": "70ae3828-9430-4d12-8787-9d4e3513d5e4",
771 | "colab": {
772 | "base_uri": "https://localhost:8080/",
773 | "height": 326
774 | }
775 | },
776 | "source": [
777 | "!pip install lightfm"
778 | ],
779 | "execution_count": 47,
780 | "outputs": [
781 | {
782 | "output_type": "stream",
783 | "text": [
784 | "Collecting lightfm\n",
785 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/e9/8e/5485ac5a8616abe1c673d1e033e2f232b4319ab95424b42499fabff2257f/lightfm-1.15.tar.gz (302kB)\n",
786 | "\r\u001b[K |█ | 10kB 23.9MB/s eta 0:00:01\r\u001b[K |██▏ | 20kB 6.2MB/s eta 0:00:01\r\u001b[K |███▎ | 30kB 7.5MB/s eta 0:00:01\r\u001b[K |████▍ | 40kB 7.6MB/s eta 0:00:01\r\u001b[K |█████▍ | 51kB 6.8MB/s eta 0:00:01\r\u001b[K |██████▌ | 61kB 8.0MB/s eta 0:00:01\r\u001b[K |███████▋ | 71kB 8.4MB/s eta 0:00:01\r\u001b[K |████████▊ | 81kB 8.9MB/s eta 0:00:01\r\u001b[K |█████████▊ | 92kB 8.7MB/s eta 0:00:01\r\u001b[K |██████████▉ | 102kB 9.3MB/s eta 0:00:01\r\u001b[K |████████████ | 112kB 9.3MB/s eta 0:00:01\r\u001b[K |█████████████ | 122kB 9.3MB/s eta 0:00:01\r\u001b[K |██████████████ | 133kB 9.3MB/s eta 0:00:01\r\u001b[K |███████████████▏ | 143kB 9.3MB/s eta 0:00:01\r\u001b[K |████████████████▎ | 153kB 9.3MB/s eta 0:00:01\r\u001b[K |█████████████████▍ | 163kB 9.3MB/s eta 0:00:01\r\u001b[K |██████████████████▍ | 174kB 9.3MB/s eta 0:00:01\r\u001b[K |███████████████████▌ | 184kB 9.3MB/s eta 0:00:01\r\u001b[K |████████████████████▋ | 194kB 9.3MB/s eta 0:00:01\r\u001b[K |█████████████████████▊ | 204kB 9.3MB/s eta 0:00:01\r\u001b[K |██████████████████████▊ | 215kB 9.3MB/s eta 0:00:01\r\u001b[K |███████████████████████▉ | 225kB 9.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████ | 235kB 9.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████ | 245kB 9.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████ | 256kB 9.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████▏ | 266kB 9.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▎ | 276kB 9.3MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▍ | 286kB 9.3MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▍| 296kB 9.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 307kB 9.3MB/s \n",
787 | "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from lightfm) (1.18.5)\n",
788 | "Requirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from lightfm) (1.4.1)\n",
789 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from lightfm) (2.23.0)\n",
790 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->lightfm) (3.0.4)\n",
791 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->lightfm) (1.24.3)\n",
792 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->lightfm) (2020.6.20)\n",
793 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->lightfm) (2.10)\n",
794 | "Building wheels for collected packages: lightfm\n",
795 | " Building wheel for lightfm (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
796 | " Created wheel for lightfm: filename=lightfm-1.15-cp36-cp36m-linux_x86_64.whl size=709137 sha256=33ac00a0cdcb0393e76e2872cad626099fd1a69033af124ee3163a08930640ab\n",
797 | " Stored in directory: /root/.cache/pip/wheels/eb/bb/ac/188385a5da6627956be5d9663928483b36da576149ab5b8f79\n",
798 | "Successfully built lightfm\n",
799 | "Installing collected packages: lightfm\n",
800 | "Successfully installed lightfm-1.15\n"
801 | ],
802 | "name": "stdout"
803 | }
804 | ]
805 | },
806 | {
807 | "cell_type": "code",
808 | "metadata": {
809 | "id": "o6E15CB99xr9",
810 | "outputId": "85eb9969-cbdb-4c35-d3dd-048c33280e59",
811 | "colab": {
812 | "base_uri": "https://localhost:8080/",
813 | "height": 34
814 | }
815 | },
816 | "source": [
817 | "# from tutorial at https://github.com/lyst/lightfm\n",
818 | "from lightfm import LightFM\n",
819 | "from lightfm.evaluation import precision_at_k\n",
820 | "\n",
821 | "# The tutorial loads MovieLens 100k with fetch_movielens();\n",
822 | "# here we train on the book-ratings matrix `train`\n",
823 | "# built earlier in this notebook instead.\n",
824 | "\n",
825 | "# Instantiate and train the model\n",
826 | "model = LightFM(loss='warp')\n",
827 | "model.fit(train.tocoo(), epochs=30, num_threads=2)"
828 | ],
829 | "execution_count": 53,
830 | "outputs": [
831 | {
832 | "output_type": "execute_result",
833 | "data": {
834 | "text/plain": [
835 | ""
836 | ]
837 | },
838 | "metadata": {
839 | "tags": []
840 | },
841 | "execution_count": 53
842 | }
843 | ]
844 | },
845 | {
846 | "cell_type": "code",
847 | "metadata": {
848 | "id": "nzXuYWpBPmby"
849 | },
850 | "source": [
851 | "# Evaluate the trained model\n",
852 | "test_precision = precision_at_k(model, test.tocoo(), k=5)"
853 | ],
854 | "execution_count": 54,
855 | "outputs": []
856 | },
857 | {
858 | "cell_type": "code",
859 | "metadata": {
860 | "id": "Djh5X7d5PqTg",
861 | "outputId": "04139def-9bc7-4d5a-a01f-3f9f5dab383b",
862 | "colab": {
863 | "base_uri": "https://localhost:8080/",
864 | "height": 34
865 | }
866 | },
867 | "source": [
868 | "test_precision # .mean()"
869 | ],
870 | "execution_count": 55,
871 | "outputs": [
872 | {
873 | "output_type": "execute_result",
874 | "data": {
875 | "text/plain": [
876 | "array([0.6, 0. , 0. , ..., 0.2, 0.2, 0.4], dtype=float32)"
877 | ]
878 | },
879 | "metadata": {
880 | "tags": []
881 | },
882 | "execution_count": 55
883 | }
884 | ]
885 | },
886 | {
887 | "cell_type": "code",
888 | "metadata": {
889 | "id": "GhIKv53VSvnr",
890 | "outputId": "9d5945c9-c010-4d11-e788-b07c049a08a4",
891 | "colab": {
892 | "base_uri": "https://localhost:8080/",
893 | "height": 34
894 | }
895 | },
896 | "source": [
897 | "test_precision.mean()"
898 | ],
899 | "execution_count": 56,
900 | "outputs": [
901 | {
902 | "output_type": "execute_result",
903 | "data": {
904 | "text/plain": [
905 | "0.11688069"
906 | ]
907 | },
908 | "metadata": {
909 | "tags": []
910 | },
911 | "execution_count": 56
912 | }
913 | ]
914 | },
915 | {
916 | "cell_type": "code",
917 | "metadata": {
918 | "id": "JmARW1jjS7y6",
919 | "outputId": "91720f41-3e50-45c4-e2d0-63cdd0c8cec5",
920 | "colab": {
921 | "base_uri": "https://localhost:8080/",
922 | "height": 476
923 | }
924 | },
925 | "source": [
926 | "sample_recommendation(model, [3, 9999, 15000], train.tocsr(), np.array(book_labels))"
927 | ],
928 | "execution_count": 57,
929 | "outputs": [
930 | {
931 | "output_type": "stream",
932 | "text": [
933 | "User 3\n",
934 | " Known positives:\n",
935 | " The Atlantis Complex (Artemis Fowl, #7)\n",
936 | " Sentinel (Covenant, #5)\n",
937 | " The Devil Wears Prada (The Devil Wears Prada, #1)\n",
938 | " Recommended:\n",
939 | " The Life and Times of the Thunderbolt Kid\n",
940 | " The Atlantis Complex (Artemis Fowl, #7)\n",
941 | " Beautiful Creatures (Caster Chronicles, #1)\n",
942 | "User 9999\n",
943 | " Known positives:\n",
944 | " City of Glass (The Mortal Instruments, #3)\n",
945 | " The Magicians' Guild (Black Magician Trilogy, #1)\n",
946 | " Bridge to Terabithia\n",
947 | " Recommended:\n",
948 | " City of Glass (The Mortal Instruments, #3)\n",
949 | " Enchanters' End Game (The Belgariad, #5)\n",
950 | " Bridge to Terabithia\n",
951 | "User 15000\n",
952 | " Known positives:\n",
953 | " Enchanters' End Game (The Belgariad, #5)\n",
954 | " The Life and Times of the Thunderbolt Kid\n",
955 | " Beautiful Creatures (Caster Chronicles, #1)\n",
956 | " Recommended:\n",
957 | " Where the Heart Is\n",
958 | " Gone Girl\n",
959 | " Tangled (Tangled, #1)\n"
960 | ],
961 | "name": "stdout"
962 | }
963 | ]
964 | },
965 | {
966 | "cell_type": "code",
967 | "metadata": {
968 | "id": "gDBItHUWQZo0"
969 | },
970 | "source": [
971 | ""
972 | ],
973 | "execution_count": 57,
974 | "outputs": []
975 | }
976 | ]
977 | }
--------------------------------------------------------------------------------
/chapter05/pso_it0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Artificial-Intelligence-with-Python-Cookbook/31f15f25c0bdc3286cbac6da75573ce9f069d25a/chapter05/pso_it0.png
--------------------------------------------------------------------------------
/chapter05/pso_it1322.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Artificial-Intelligence-with-Python-Cookbook/31f15f25c0bdc3286cbac6da75573ce9f069d25a/chapter05/pso_it1322.png
--------------------------------------------------------------------------------
/chapter05/pso_it1323-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Artificial-Intelligence-with-Python-Cookbook/31f15f25c0bdc3286cbac6da75573ce9f069d25a/chapter05/pso_it1323-1.png
--------------------------------------------------------------------------------
/chapter05/pso_it3-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Artificial-Intelligence-with-Python-Cookbook/31f15f25c0bdc3286cbac6da75573ce9f069d25a/chapter05/pso_it3-1.png
--------------------------------------------------------------------------------
/chapter05/pso_it32-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Artificial-Intelligence-with-Python-Cookbook/31f15f25c0bdc3286cbac6da75573ce9f069d25a/chapter05/pso_it32-1.png
--------------------------------------------------------------------------------
/chapter05/skipgram.dot:
--------------------------------------------------------------------------------
1 | digraph G {
2 | # dot -Tpng skipgram.dot > skipgram.png
3 | rankdir=LR;
4 | in [label="w(t)"];
5 | hidden [label=""];
6 | out1 [label="w(t-2)"];
7 | out2 [label="w(t-1)"];
8 | out3 [label="w(t+1)"];
9 | out4 [label="w(t+2)"];
10 | in -> hidden;
11 | hidden -> out1;
12 | hidden -> out2;
13 | hidden -> out3;
14 | hidden -> out4;
15 | }
--------------------------------------------------------------------------------
/chapter05/skipgram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Artificial-Intelligence-with-Python-Cookbook/31f15f25c0bdc3286cbac6da75573ce9f069d25a/chapter05/skipgram.png
--------------------------------------------------------------------------------
/chapter05/solving-n-queens.md:
--------------------------------------------------------------------------------
1 | iteration: 0
2 |
3 | 
4 |
5 | iteration 0 - best particle: [6 1 2 4 0 0 3 4], score: 23
6 |
7 | iteration: 3
8 |
9 | 
10 |
11 | iteration: 32
12 |
13 | 
14 |
15 | iteration 500 - best particle: [6 1 5 2 0 3 7 2], score: 27
16 |
17 | iteration 1000 - best particle: [6 1 5 2 0 3 7 2], score: 27
18 |
19 | iteration: 1322
20 |
21 | 
22 |
23 | iteration: 1323
24 |
25 | 
26 |
27 | iteration 1323 - best particle: [5 3 1 7 4 6 0 2], score: 28
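
The log does not define `score`. One reading that reproduces all three reported values (23, 27 and 28) counts the non-attacking queen pairs, with each particle read as `board[column] = row`. For 8 queens the maximum is C(8, 2) = 28, so a score of 28 marks a valid solution. The helper below is a hypothetical sketch under that assumption. It is not the fitness function used in the notebook.

```python
from itertools import combinations


def n_queens_score(board):
    """Count the non-attacking queen pairs for one PSO particle.

    Assumes board[i] is the row of the queen in column i, so every
    column holds exactly one queen and clashes can only happen on a
    shared row or a shared diagonal. Hypothetical helper, not the
    notebook's own fitness function.
    """
    score = 0
    for (col_a, row_a), (col_b, row_b) in combinations(enumerate(board), 2):
        same_row = row_a == row_b
        same_diagonal = abs(row_a - row_b) == abs(col_a - col_b)
        if not (same_row or same_diagonal):
            score += 1
    return score


# Particles taken from the log above
print(n_queens_score([6, 1, 2, 4, 0, 0, 3, 4]))  # 23 (iteration 0)
print(n_queens_score([6, 1, 5, 2, 0, 3, 7, 2]))  # 27 (iterations 500 and 1000)
print(n_queens_score([5, 3, 1, 7, 4, 6, 0, 2]))  # 28 (iteration 1323, a solution)
```

The score rises as conflicts disappear, so the swarm can rank partial boards. Reaching C(n, 2) signals that no two queens attack each other.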
--------------------------------------------------------------------------------
/chapter06/Optimizing a website.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 11,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# Partly adapted from Lilian Weng's code: https://github.com/lilianweng/multi-armed-bandit\n",
10 | "import random\n",
11 | "import numpy as np  # used by the UCB1 agent below\n",
12 | "\n",
13 | "class Bandit:\n",
14 | "    def __init__(self, K=2, probs=None):\n",
15 | "        \"\"\"A multi-armed bandit with Bernoulli (0/1) rewards\n",
16 | "\n",
17 | "        Parameters:\n",
18 | "        -----------\n",
19 | "        K - the number of arms\n",
20 | "        probs - optional list of K win probabilities; drawn uniformly at random if None\n",
21 | "        \"\"\"\n",
22 | "        self.K = K\n",
23 | "        if probs is None:\n",
24 | "            self.probs = [\n",
25 | "                random.random() for _ in range(self.K)\n",
26 | "            ]\n",
27 | "        else:\n",
28 | "            assert len(probs) == K\n",
29 | "            self.probs = probs\n",
30 | "        self.best_probs = max(self.probs)\n",
31 | "\n",
32 | "    def play(self, i):\n",
33 | "        \"\"\"Play the i-th arm.\n",
34 | "\n",
35 | "        Returns a reward of 1 or 0.\n",
36 | "        \"\"\"\n",
37 | "        if random.random() < self.probs[i]:\n",
38 | "            return 1\n",
39 | "        else:\n",
40 | "            return 0"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 316,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "class Agent:\n",
50 | "    def __init__(self, env):\n",
51 | "        \"\"\"\n",
52 | "        An abstract agent\n",
53 | "\n",
54 | "        Parameters:\n",
55 | "        -----------\n",
56 | "        env - an environment (a bandit)\n",
57 | "        \"\"\"\n",
58 | "\n",
59 | "        self.env = env\n",
60 | "        self.listeners = {}\n",
61 | "        self.metrics = {}\n",
62 | "        self.reset()\n",
63 | "\n",
64 | "    def reset(self):\n",
65 | "        for k in self.metrics:\n",
66 | "            self.metrics[k] = []\n",
67 | "\n",
68 | "    def add_listener(self, name, fun):\n",
69 | "        \"\"\"Add a listener to record a metric\n",
70 | "        \"\"\"\n",
71 | "        self.listeners[name] = fun\n",
72 | "        self.metrics[name] = []\n",
73 | "\n",
74 | "    def run_metrics(self, i):\n",
75 | "        \"\"\"Calculate metrics after an action i\"\"\"\n",
76 | "        for key, fun in self.listeners.items():\n",
77 | "            fun(self, i, key)\n",
78 | "\n",
79 | "    def run_one_step(self):\n",
80 | "        \"\"\"A single choice.\"\"\"\n",
81 | "        raise NotImplementedError\n",
82 | "\n",
83 | "    def run(self, n_steps):\n",
84 | "        \"\"\"Play n_steps rounds.\n",
85 | "        \"\"\"\n",
86 | "        raise NotImplementedError"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 317,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "class UCB1(Agent):\n",
96 | "    def __init__(self, env, alpha=2.):\n",
97 | "        \"\"\"\n",
98 | "        The UCB1 agent\n",
99 | "\n",
100 | "        Parameters:\n",
101 | "        -----------\n",
102 | "        env - an environment (a bandit)\n",
103 | "        alpha - exploration weight in the bonus sqrt(alpha * log(t) / (1 + n_i))\n",
104 | "        \"\"\"\n",
105 | "        self.alpha = alpha\n",
106 | "        super(UCB1, self).__init__(env)\n",
107 | "\n",
108 | "    def run_exploration(self):\n",
109 | "        \"\"\"Initial exploration: play every arm once\n",
110 | "        \"\"\"\n",
111 | "        for i in range(self.env.K):\n",
112 | "            self.estimates[i] = self.env.play(i)\n",
113 | "            self.counts[i] += 1\n",
114 | "            self.history.append(i)\n",
115 | "            self.run_metrics(i)\n",
116 | "            self.t += 1\n",
117 | "\n",
118 | "    def update_estimate(self, i, r):\n",
119 | "        \"\"\"Incremental update of the mean reward estimate for arm i\n",
120 | "        \"\"\"\n",
121 | "        self.estimates[i] += (r - self.estimates[i]) / (self.counts[i] + 1)\n",
122 | "\n",
123 | "    def reset(self):\n",
124 | "        self.history = []\n",
125 | "        self.t = 0\n",
126 | "        self.counts = [0] * self.env.K\n",
127 | "        self.estimates = [None] * self.env.K\n",
128 | "        super(UCB1, self).reset()\n",
129 | "\n",
130 | "    def run(self, n_steps):\n",
131 | "        \"\"\"Play n_steps rounds\n",
132 | "\n",
133 | "        This count does not include the exploration phase.\n",
134 | "        \"\"\"\n",
135 | "        assert self.env is not None\n",
136 | "        self.reset()\n",
137 | "        if self.estimates[0] is None:\n",
138 | "            self.run_exploration()\n",
139 | "        for _ in range(n_steps):\n",
140 | "            i = self.run_one_step()\n",
141 | "            self.counts[i] += 1\n",
142 | "            self.history.append(i)\n",
143 | "            self.run_metrics(i)\n",
144 | "\n",
145 | "    def upper_bound(self, i):\n",
146 | "        return np.sqrt(\n",
147 | "            self.alpha * np.log(self.t) / (1 + self.counts[i])\n",
148 | "        )\n",
149 | "\n",
150 | "    def run_one_step(self):\n",
151 | "        \"\"\"A single choice.\n",
152 | "\n",
153 | "        Pick the arm with the highest optimistic value,\n",
154 | "        i.e. its estimate plus its upper-confidence bonus\n",
155 | "        \"\"\"\n",
156 | "        i = max(\n",
157 | "            range(self.env.K),\n",
158 | "            key=lambda i: self.estimates[i] + self.upper_bound(i)\n",
159 | "        )\n",
160 | "        r = self.env.play(i)\n",
161 | "        self.update_estimate(i, r)\n",
162 | "        self.t += 1\n",
163 | "        return i"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 318,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "from scipy import stats\n",
173 | "\n",
174 | "def update_regret(agent, i, key):\n",
175 | "    \"\"\"Add the regret of action i to the cumulative regret\n",
176 | "    \"\"\"\n",
177 | "    regret = agent.env.best_probs - agent.env.probs[i]\n",
178 | "    if agent.metrics[key]:\n",
179 | "        agent.metrics[key].append(\n",
180 | "            agent.metrics[key][-1] + regret\n",
181 | "        )\n",
182 | "    else:\n",
183 | "        agent.metrics[key] = [regret]\n",
184 | "\n",
185 | "def update_rank_corr(agent, i, key):\n",
186 | "    \"\"\"Spearman rank correlation of true vs. estimated win probabilities (tracks convergence)\"\"\"\n",
187 | "    if agent.t < agent.env.K:\n",
188 | "        agent.metrics[key].append(0.0)\n",
189 | "    else:\n",
190 | "        agent.metrics[key].append(\n",
191 | "            stats.spearmanr(agent.env.probs, agent.estimates)[0]\n",
192 | "        )"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 319,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "%matplotlib inline\n",
202 | "import pandas as pd\n",
203 | "import matplotlib.pyplot as plt\n",
204 | "\n",
205 | "\n",
206 | "def plot_stats(title):\n",
207 | "    df = pd.DataFrame(agent.metrics)  # plots the global `agent`\n",
208 | "    df['t'] = list(range(len(df)))\n",
209 | "    ax = df.plot(x='t', y='regret', legend=False)\n",
210 | "    ax2 = ax.twinx()\n",
211 | "    ax.spines['left'].set_color('b')\n",
212 | "    ax.spines['left'].set_linewidth(1.5)\n",
213 | "    ax.tick_params(axis='y', colors='b')\n",
214 | "\n",
215 | "    df.plot(x='t', y='corr', ax=ax2, legend=False, color=\"r\")\n",
216 | "    ax2.spines['right'].set_color('r')\n",
217 | "    ax2.spines['right'].set_linewidth(1.5)\n",
218 | "    ax2.tick_params(axis='y', colors='r')\n",
219 | "\n",
220 | "    ax.figure.legend(loc='center', bbox_to_anchor=(0.7, 0.5))\n",
221 | "    plt.title(title)"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 327,
227 | "metadata": {},
228 | "outputs": [
229 | {
230 | "data": {
231 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY0AAAEWCAYAAACaBstRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8FPX9+PHXO5uEQMJ9y42EW+5TVKgn1gNrpRVbRMVi61ls9We/ba1a7VerrX611ooneOFVFRFvRTy45ZCbEEHCFSAkJECSze7798dMyHIk2YQkM5u8n49HHjv7mdmZ9052572fz8x8PqKqGGOMMdGI8zoAY4wxscOShjHGmKhZ0jDGGBM1SxrGGGOiZknDGGNM1CxpGGOMiZolDRPzROQqEfmqqpc1xhzLkoYxVUxERojIxyKSJSK7ReR1EWlbxvLNROQtETkgIltE5IqajNeYirCkYUzVawpMAzoDnYBc4Lkyln8cKARaA78AnhCRPtUcozGVYknDxAwRuUNENolIroisEZGflLKcisjNIpIuIntE5EERiTtqmYdEZJ+IfC8i50eUXy0ia91tpIvIdRWNU1XfV9XXVXW/qh4E/gWMKiXWZOCnwJ9VNU9VvwJmARMrul1jaoIlDRNLNgGnA42Bu4EXy2j2+QkwBBgEjAOuiZg3HFgPtAD+DjwjIuLOywQuBBoBVwMPi8ggABHpKCLZZfyV1qx0BrC6lHndgSJV3RBRtgKwmobxJUsaJma4v963q2pYVV8FNgLDSln8AVXNUtUfgEeACRHztqjqU6oaAqYDbXGahlDV91R1kzq+AD7CSVSo6g+q2qSMv5ePDkJE+gF3AreVEmcKsP+oshygYRS7xJgaZ0nDxAwRuVJElhf/sgf64tQWjmdrxPQW4KSI5zuLJ9zmI3AO3ojI+SKywD2JnQ38uIxtlBdvN+B94BZV/bKUxfJwajWRGuGcBzHGdyxpmJggIp2Ap4Abgeaq2gRYBUgpL+kQMd0R2B7FNuoBbwIPAa3dbcwp3obbPJVXxt8vjor3E+CvqvpCGZvdAMSLSGpEWX9Kb84yxlPxXgdgTJSSAQV2g3PCGqemUZrbRGQhTg3iFuCfUWwjEajnbqPIPUF+Lk5ywm3qSilvJSLSDvgM+Jeq/qesZVX1gIj8F7hHRK4FBuCcgzk1iniNqXFW0zAxQVXXAP8A5gO7gFOAr8t4yTvAUmA58B7wTBTbyAVuBl4D9gFX4FzJVFHXAl2BuyJrIsUzReR/ROT9iOWvB+rjnIR/BfiNqlpNw/iS2CBMprYREQVSVTXN61iMqW2spmGMMSZqljSMMcZEzZqnjDHGRM3zmoYIj4jwiNdxGGNMTBF5BJEaP3Z6XtMQYa5I4uj69RM8jcMYY2JKfj6F4bAGVWv0x78v7tNITu5Hbu4Sr8MwxpjYMWYMvb744kBNb9bz5iljjDGxw5KGMcaYqEWVNETYLMJ3IiwXYYlb1kyEj0XY6D42dctFhEdFSBNhpQiDqvMNGGOMKYXIs4hkIrKqlPmCyKOIpCGyEncYgLJUpKbxI1UGqDLEfX4H8KkqqcCn7nOA84FU928K8EQFtmGMMabqPA+MLWN+hY/XJ9I8NQ5nLALcx0siymeooqosAJqIUOr4yMYYY6qJ6jwgq4wlxgEzUFVUFwBNKGM8e4g+aSjwkQhLRZjilrVWZYc7vRN3EBugHUeOZZDhlhljjKlCCZCAyJKIvynlv+oIFT5eR3vJ7WmqbBOhFfCxCOsiZ6qiIlTohg83+UwBegSDFXmlMcYYgCAEUR1S/pJVJ6qkoco29zFThLdwhtjcJUJbVXa4zU+Z7uLbOHIAnPZu2dHrnAZME2FuQgKjT+RNnBBVCIchEPAsBGOqXGYmtGwJUtoYVdUoPx+2uj9eAwHnO1Z8E3HkY0Wno1k2HIY5c2DMGEgpd+iTEomJcMopNbe/8vIg
GISmTSv0slXbcvh0bSahcJjx+w6W/4LyRXW8jlRu0hAhGYhTJdedPhe4B2ecgUnA/e7jO+5LZgE3ijATGA7kRDRj+U+XLrBlC/zpT87jrbfCXXfBBRfAeecdu3ybNrB/P7So1Aigjqws50MD8NVXzod8xAjneUoKnHkmxMdD27ZV9yFWhZ07nQ8qOF8u1SOTZUoKNGtWNdsrzerVsGCB896Kbd0K7dpBXDVfAd6zJ2zeDAcOQMOGUK/e8ZeLj4dBg0r2TTgMf/5zyQGqWNeu0KdP2dvcvds5eANkZ8PX7hAgw4c7n6GFC51lUlKcz9wdd8C//w0JCc4++eILePppmDGj9M9CTg6sWVPy/L334L77Sp73LWusqmqw6vgX6tSou++u3Otatar+zyE430VA27RxmmiO004TViUUdmbE5+wjEAzSukETJrifgyaHqmRE4FnAjYgcPl6jWubxutxuREToCrzlPo0HXlblPhGa4wxW0xFnDOafqZIlggD/wjljfxC4WpVSb/cWYW5KypDRntwRvmULdO5cudcePOgcfEpLHps2wY03Ogfpli3h+echNxfmzYOf/jS6bTRtCh07OtMrVjiP/fqVfvBISIAnnoDFi+HJJ0vKi18bjXHjnANkenrJl/+CC6Kric2f7ySDyH26eTNs2wajRjnPh9RoTfrEFP9v9+wpKUtwu7upiTbVZs2cHxjg/BJudPRQ4q7I+CKddJKTUM49t2ZrHHPmOJ/d88+HAQOgcWNn+8UxRD5WdDqaZdetgx49on/PqnDxxc50z55wxhmVe98VoEuXsn97JnNa9znmt8jxtMndQ6sD+yjoN4AurRrSLDkR3n2XXjt25K1VbVjqC0VeAcbgjHO/C/gL4HyIVf+DyDHHa1TLPBj7ou8pz5LGl186H5Bf/9o5GM+YAcuWwTXXwMCBxx4ob7rJSRZHq1cPCgqc6YQE+PnP4cUXj1ymYUMnaRQ7/XRn+8W2bHF+4X70EYRC8MEHRx6Y3nErcn37wsknHxtDYSG8//6RZePGOY+zZzs1pGHDnAQg4rzP/HwYOdJZZv9+mDr1+PupOixc6DyqwqJFzvtKTq6+7X3yibNPv/zSSWz33eccVI/no4+cfVPswAH4739h+/aSJo9du2D58rK3WVjobHfECCcBqDo/JgIB58dAIFByoJs82fkf5OTA2LFObeyKK5xmpg8/hIkTy95Wx47OAbpY587OgdNE5+BBpxY8ZEiNJNj/fX8tT36RzoiuzTi7V2vndgncjwMQFycE4oQuzZNpUM9pEGqRkkj7pg1KVuJ0I1J20qgGtTtpbNjg/Oro3RtmznQOypMmOb+CPv0ULrzQWW7hQueAGo3du50qbKTf/x4+/tj5kr722rHLT53qNEeJOMlo9GgnWS1a5HzRExPL325GBrzxBtxyS+kf6smTIS0NuneHCROcZq6K2L/fqQkVfyb69HGSWV5e2a8rtnWrU6tKSiopKyhwahpduzrPr7rKaYK57baKxWaMj23YlcsbSzNYu2P/EeXBUJgtew9ysDBEOKwUhZWQKoVFYX4xvCP3XtLX/bFfCZY0qjhpHDoEDRqUvxw41fvmzaNf94ED8NlnTjLKyoKhQ0vmrVnjJIkRI5zmqeL2bGNMzNqadZCP1+yioChMWJWwe/APKxwqLGL6/C2Ew0qfkxoRiDsyCdRPDNC1RQoJgTgCcU4t4qTG9fnliE7HLFshHiUNX/RyWy3cE02AU90fNw5+8xunSeKmm0raMcs7kXk8yclw0UXO9NFNRb17O80JxpiYoarc//46vkrbQ6i4RhBWisJhQiFlx/78Us89xAkM69KMh38+gLaN69ds4B6ofUmjoAAeftg5+Qbw7rslzVC//rV3cRljfOv1JRk8OS+doZ2b0rRBIvEBIT4ujnj33EK7pvW5bHB7WqTUI06EOIFAnFS+aSmGxWbSyM52TkgGAs55i8mT4e9/d85N3HpryXJxcc45BGNMnXOw
sIiicEn1IDe/iP8uzSC3oIiQW5MorlW8s3wbI7s256VrhxN3Ik1GdUDsJY2DB50T2dde61xx89vfOuXFl3RGWrnSudbdGFNnfLlxN9PmpfN12h7Cx2lSSkqIIz4ujoBbiwjECSe3TOEfP+tvCSMKsZc0vvrKeXz66bKXa9GicucrjDExSVWZMX8Lf5m1GoDxg9vTo82R54gHdGjCkM7VfANrLRd7SePQoSOfn322c231hg3OyekGDZz7JObN8yY+Y0yFhMLKK4t+YOu+gxwqDJFXUOT2CuJcnaQ408V3Tita0muIW64Kew8UsOyHbJITA3z2+zG0bpRU9oZNpcRu0njiCRg//viXyv7sZzUbkzGmUhZvzuJ3r63gh6yDJASE5HrxJCfGExcHcYdveHMecW98i7wRLs49ES0iJASEP13Qi6tO7Ux8wAYlrS6xmzTGjq3YvRXGGF/5YNUObp65nJYp9bj74j5cObJTnbwaKdbEXtIo7t4hyaqexsSi3PwgLyzYwoMfrmdAhyY8O2koTZOj6BXB+ELsJI20NKd/qOKaRv3afxONMbXB1qyDvLVsG/nBEAVFYV6Yv4XCUJize7XisQmDqJ9owxLEkthJGhs2lCSMlJTouwgxxngmHFZ+++pylm7ZR3yckBgfR8OkeG74UTeuOrWzXeIag2InaRT3Ivv++3DqqSVdVBtjfOuxz9JYumUflw5qxz9/NqD8Fxjfi51LDAoLnceOHUsfV8AY4xtvLcvg4U820LttIx66rL/X4Zgq4v+ahqozvkTxTX2ljbZmjPGNNdv3M/XVFcQJ3D2ujzVD1SL+TxqXX+4MZBM5gpkxxrdyDga54eVvSQzE8eHUM+jSohoH1zI1zt9JY8eOYwc1OpGxuY0xVWpb9iHeWJLBoWCIYChMYVGYZVv3sW3fIV7+1XBLGLWQv5NG5OBGAPfcY5faGuOxQ4UhPli9g7eXbWfZD/vYn19Evfg4EgJxJASEBonxPPSz/tbHUy3l76QxfLgzNnOxG27wLhZj6qilW7JY9kM2BUVh8oMhZq3Yzpa9B+nUvAFndG/Jred0p2vLFK/DNDXE30mj4VGjGDazXy7G1ITsg4V8tHoXLy3cwoqMnCPmpbZKYdrEwZzdq7Wd4K6D/J00QiHn8tr9++HOO72Oxpha5YX5m1m7M5dQqHh40zBFYWVPXgELv89CFXq2achvz05l4ohOJNeLp158nPUPVcf5O2m8+KLzWNrgvMaY48o5GGTx5izmfLeDzXsPsPdAIUUhZ8zropByKBjiYGEIgDaNkgjECfEBZ0Ci+DjhkgHtmDCsI0M7N7UkYY7g36QRCnkdgTExoaAoxMxFW9mbV0BhSAmGwrz5bQbZB4MADO7UlAEdmhwe89oZ/1poUC+e34w5mUZJ1ruCiZ5/k8aiRV5HYIyvhcPKvI27efzzNBZv3gdAYnwciYE4OjRrwAM/7Uf/9k1o09h6hDZVx79JY/16ryMwxhOqyvpdueTmF5EfDLEnr4BgkRJ2R7LbnVvAB6t3sj37EDmHgjRLTuR/ftyTX53e1ZqSTLXzX9LIzIT77oNHH/U6EmNqVDiszJi/mQ9W72RBelaZyzZMiuei/icxvEszzu/blsT42OlGzsQ2/yWNm2+GV1/1Ogpjaty7K7dz17traJgUz3VndOX01JYkxsfRqH48DZMSiBMQhDiBxg0SqBdv41CYmue/pHH0CfBPP/UmDmNq2GOfpdGxWQM++91oG+Pa+Jb/PpmBiF9PHTrAmWd6F4sxNeSJuZtIy8zj50M7WMIwvua/T2dk31K/+513cRhTQ2av3M4DH6yjRUoik0/r4nU4xpQp6qQhQkCEZSLMdp93EWGhCGkivCpColtez32e5s7vXKGITjutZLpTpwq91JhYsO9AId9l5LDo+yxuf2MFN768DIBpVw4hKcHOU5gqJjIWkfWIpCFyx3Hmd0Tkc0SWIbISkR+XtbqKnNO4BVgLFA+b9wDwsCoz
RfgPMBl4wn3cp0o3ES53l/t51FuJvGSwXbsKhGeM/6gqobDTTUfm/gLeXr6Nf32WRmEofHiZ8/u24Xfn9qBbK+v0z1QxkQDwOHAOkAEsRmQWqmsilvoT8BqqTyDSG5gDpf/YjyppiNAeuAC4D7hVBAHOBK5wF5kO3IWTNMa50wBvAP8SQVSJri+QsPtleuutY7tGNyaGrN+Zyw0vf0taZt4R5U0bJPDgxf1oXD+BLi2S6dTcxpww1WYYkIZqOgAiM3GO0ZFJQympDDQGtpe1wmhrGo8AtwPF3c42B7JVKXKfZwDF1YJ2wFYAVYpEyHGX3xPVloqTxrBhUYZmTM37cuNupn+zmWDIqUkU/xWFwxSGwmzPzifrQCGJgTiuOrUzrRslUS8+jtNTW9ClRbKd7DY15fDx2JUBDD9qmbuAjxC5CUgGzi5rheUmDREuBDJVWSrCmIpEW856pwBTgB7BYMSM4qQRZ18q4y978wpYkJ7Fou/38uLCHwiI0OukRsTHyeEO/+olxNMoTjilXRNaNazHpYPaWU3CVJsESEBkSUTRNFSnVXA1E4DnUf0HIiOBFxDpi2r4eAtHU9MYBVwswo+BJJxqzP8BTUSId2sb7YFt7vLbgA5AhgjxONWdvUevVJVpwDQR5iYkMPrwDEsaxmdCYeWvs9fwwoIthMJKnMCEYR35/bk9aJpsY9Yb7wQhiOqQMhYpPh4XizxWF5sMjAVAdT4iSUALIPN4Kyw3aajyB+APAG5N4/eq/EKE14HLgJnAJOAd9yWz3Ofz3fmfRX0+A0qShvWhYzyQV1DEtHnpLEjfS+b+fHbnFnAoGCKscGbPVtxyViqdmyfTuIH1DGtiwmIgFZEuOMnickrORRf7ATgLeB6RXjiVg92lrfBE7gj/f8BMEe4FlgHPuOXPAC+IkAZkuUFGr3jsDKtpmGqiqhSGwoTDEFK3I8CwkrHvENe9sJRt2Yfo174x3Vs35OxerWmQGKBb64Zc3P8kr0M3pmJUixC5EfgQCADPoroakXuAJajOAn4HPIXIVJyT4lehpQ9iVKGkocpcYK47nY5zZv7oZfKB8RVZ7xGsecpUs1/NWMona3cdd16jpHje+PVIhnS2oYVNLaE6B+cy2siyOyOm1+CchoiK//qesqRhqtHsldv5ZO0uBnZswnl92hAQQQQCcUKcCKO6Nadbq4blr8iYOsqShqkzwmHl0U83ktoqhdevG2mXvRpTCf771ljSMNUgFFZumrmMDbvyuOFH3SxhGFNJ/vvmWNIw1WDR91m8t3IHP+rRkgv7tfU6HGNilv+OzJY0TDWYuyGThIDw2BWDrJZhzAmwcxqm1jpUGGLvgQI27MrlyS/SGdm1OSn1/PeRNyaW+O8bZEnDVNKevAIKisKEw8qyrdn8z3+/I6+g6PD8ywa39zA6Y2oHSxom5q3alsMf3/qOFRk5R5QnJwb46yV9adoggV5tG3FyS+t63JgT5d+kYd2ImCgcKgwx9dXl7Mkr4PaxPWienEicOB0I9mvf2O65MKaK+S9pqFrCMOU6VBhi1/58Hv88jbTdeTx31VDG9GjldVjG1Hr+SxrhsDVNmTKl787jsv/MJ+tAIQC/GXOyJQxjaoglDeN7e/IKWJiexVdpe/hw9U6yDhTSLDmRBy/rR5vGSYw6uYXXIRpTZ1jSML6THwzx7Q/7eHfFDuZv2sPmvQcBaJAY4IzUlvRo05AL+rWle2s7X2FMTbOkYTz3+pKtLNm8j8JQmMKiMN9ty+GHrIOIwKknN+eSge0Y1a0F/ds3ITHePhvGeMmShqlx+cEQhwpDAKTvOcBtb6ykUVI8TRokkhAQGtdP4DdjTmbiiE6c1KS+x9EaYyJZ0jA1av3OXC77zzfk5pfcdCcC79x4Gl1a2FjaxvidJQ1To977bge5+UXcPrYHDRICAIw8uYUlDGNihCUNU6NWbM2mZ5uGXD+mm9ehGGMqwX9HZ0satZaqsiIjmwEdmngdijGmkvx3
dA6H7Y7wWijnUJBbX1tB9sEg/S1pGBOz/Nc8pWo1jVpmw65cfvn0QnbnFXBO79ac16eN1yEZYyrJf0nj0CEncZhaIWPfQX76xDcAvH39KKtlGBPj/Jc05syBYNDrKMwJyj5YyHvf7eD5rzdTWBTm5V+NsIRhTC3gv6TRqpWd04hhqsp73+1g6qvLCYaUhIBwx/m9GNypqdehGWOqgP+SRjAIAwd6HYWphGAozOTpS5i3YTftm9bn/y4fwKCOTRH7EWBMreHPpJGY6HUUpoJUlbvfXc28Dbu5bnRXbjoz1cbjNqYW8t+3OhiEhASvozBRCoWVWSu28eGqXXyweieXDW7PH87v5XVYxphq4r+kUVhoSSNGbN5zgLvfXc3n63cjAlcM78i94/p6HZYxphr5K2l8+y1kZlrzlM+pKjfPXM67K7YTJ/DXcX24bHAH6icGvA7NGFPN/JU0XnnFeTzjDG/jMMcoLAqzdsd+vty4m9krd7BuZy5XDO/I1ad2JtUGQzKmzvBX0giHITkZLrvM60iM60BBEQ98sI6Zi7ZSGAoD0L11Cr8efTK/O7c7CQG7e9+YusRfScO6EPGdJ+elM2P+Fs7q2YoxPVsxpntLOjRr4HVYxphoiYwF/g8IAE+jev9xlvkZcBegwApUryhtdeUmDRGSgHlAPXf5N1T5iwhdgJlAc2ApMFGVQhHqATOAwcBe4OeqbI7qzVlnhb6zIH0vXVsk8/SkIXa/hTGxRiQAPA6cA2QAixGZheqaiGVSgT8Ao1Ddh0irslYZzc/6AuBMVfoDA4CxIowAHgAeVqUbsA+Y7C4/Gdjnlj/sLhcdq2n4Rjis/PntVSz6Potz+rS2hGFMbBoGpKGajmohzg/9cUct8yvgcVT3AaCaWdYKyz1Cq6Kq5LlPE9w/Bc4E3nDLpwOXuNPj3Oe4888SIbojjtU0fOPVJVt5YcEWLuzXllvOSvU6HGNM5bQDtkY8z3DLInUHuiPyNSIL3OasUkX1s16EgAjLgUzgY2ATkK1K8UDPkYEcDtKdn4PThHX0OqeIsAQYfLh/QlVLGh7bnx/kxQVbuGvWaoZ1bsZjEwbSINFfp76MMY4ESEBkScTflEqsJh5IBcYAE4CnECm1d9GojgaqhIABIjQB3gJ6ViKwo9c5DZgmwtyEBEYXF1rzlHeW/bCPCU8tID8YpnPzBvz7l4OsWcoYHwtCENUhZSyyDegQ8by9WxYpA1iIahD4HpENOElk8fFWWKEjtCrZwOfASKCJyOGkExnI4SDd+Y1xToiXz5qnPJEfDLFm+37+9PYqgiFl2sTBfDj1DFqk1PM6NGPMiVkMpCLSBZFE4HJg1lHLvI1TywCRFjjNVemlrTCaq6daAkFVskWoj3MW/gGc5HEZzomVScA77ktmuc/nu/M/UyW6UZWsplFjCopCvLdyB5+uy+SztZkcCoYAuOP8npxrI+sZUzuoFiFyI/AhziW3z6K6GpF7gCWoznLnnYvIGiAE3IZqqT/0o2meagtMFyGAUzN5TZXZIqwBZopwL7AMeMZd/hngBRHSgCyczBYdq2nUCFXnyqjXlmQAcHav1pzTuxXdWzdkYEcb98KYWkV1DjDnqLI7I6YVuNX9K1e5SUOVlcAxA1yoko5zOdfR5fnA+Gg2fryNWU2j+uzOLWDDrlw+XrOL15ZkcP2Yk7nujJNp3MA6iDTGRMdfl8VYTaPabNqdx0+f+Ibsg86lar8c0ZHbzuthJ7qNMRXir6RhNY1q8eHqnfx25nLqJcQx/ZphtG5Ujx6tG1rCMMZUmL+ShtU0qtSHq3cyf9Nenv9mM+2a1Odvl57C6O4tvQ7LxLhgMEhGRgb5+fleh+KppKQk2rdvT0IdG//HX0nDbu6rMmu27+f6l74lTqBpgwSmXTmYPic19josUwtkZGTQsGFDOnfuXGdrq6rK3r17ycjIoEuXLl6HU6P8lzSsearSMvfn8/rSDNbu2M/slTsIxAkfTx1N5xbJXodm
apH8/Pw6nTAARITmzZuze/dur0Opcf5KGtY8VSmZ+/P59odsnpibxoqMHOonBPjJwHZcPaqzJQxTLepywihWV/eBv5KG1TQqbNkP+7h2+hL2HigkTuCxCQO5qP9JXodlTK2wefNmvvnmG664otThJeocfx2hraZRYffMXkN+MMSTEwcz/w9nWcIwdY6qEg6HK/36oqKiUudt3ryZl19+udLrro38lTSsplGm/GCInENB9uQVsDMnnzeXZrDsh2x+ObIT5/VpQ+tGSV6HaEyN2Lx5Mz169ODKK6+kb9++vPDCC4wcOZJBgwYxfvx48vKc0RzmzJlDz549GTx4MDfffDMXXnghAHfddRcTJ05k1KhRTJw4kVAoxG233cbQoUPp168fTz75JAB33HEHX375JQMGDODhhx/27P36ib+apzZvhuxsr6PwpQ9W7eCGl5cRCh/bjdeZPcocaMuYanP3u6tZs31/la6z90mN+MtFfcpdbuPGjUyfPp1u3bpx6aWX8sknn5CcnMwDDzzAP//5T26//Xauu+465s2bR5cuXZgwYcIRr1+zZg1fffUV9evXZ9q0aTRu3JjFixdTUFDAqFGjOPfcc7n//vt56KGHmD17dpW+x1jmr6SxcKHXEfiOqvLFht38vze/o1XDekw+rQuJ8XHEx8WRlBDHWT1bWzcgpk7q1KkTI0aMYPbs2axZs4ZRo0YBUFhYyMiRI1m3bh1du3Y9fEnshAkTmDZt2uHXX3zxxdSvXx+Ajz76iJUrV/LGG864cjk5OWzcuJHExMQaflf+56+kYY6wfmcut8xcxrqduXRu3oDHJgzilPZ2r4Xxj2hqBNUlOdm5MlBVOeecc3jllVeOmL98+fKoXl+8jscee4zzzjvviGXmzp1bNcHWIv47gdDzhMd3qhXW7dzPxf/6ip3787ntvB68c+NpljCMOY4RI0bw9ddfk5aWBsCBAwfYsGEDPXr0ID09nc2bNwPw6quvlrqO8847jyeeeIKgO4zohg0bOHDgAA0bNiQ3N7fa30Ms8VdN46ST4LTTvI7Cc+Gwcufbq2mQGODDqWfQqqGd4DamNC1btuT5559nwoQJFBQUAHDvvffSvXt3/v3vfzN27FiSk5MZOnRoqeu49tpC8CX5AAAWDElEQVRr2bx5M4MGDUJVadmyJW+//Tb9+vUjEAjQv39/rrrqKqZOnVpTb8u3xOlK3cMAhLkpKUNG5+YugbZt4aKLIKLdsS763zlreXJeOg9e1o/xQzqU/wJjatDatWvp1auX12FEJS8vj5SUFFSVG264gdTU1Co98Hu6L8aModcXX+StVW1Yk5v1V/OU9T3F3z9Yx5Pz0uneOoXLBrf3OhxjYtpTTz3FgAED6NOnDzk5OVx33XVehxTz/NU8Vcfv09iTV8BTX6YzrEszHr18YJ3tpsCYqjJ16lRrUqpi/jpC1+E7wguKQtz//jqCIeVvPzmFNo3tPIYxxn+spuEDmbn5XPzY1+zcn8+vTu9Ct1YpXodkjDHH5a+kUQdrGjkHg9z48jJ27s/nH+P781M7j2GM8TF//azft8/rCGrc9S8vZdkP+7j3kr6WMIwxvueLpJEYzod4t9Izfbq3wdSgp79M5+u0vdx+Xk9+OaKT1+EYY0y5fJE0ErQQQiHnyemnextMDVm1LYd731tLn5MaMXGkJQxjasLR3aCX1S26OT5/ndN48004/3yvo6gRX6ftAeC5q4eSlBDwOBpjYs+MGTN46KGHEBH69evHX//6V6655hr27NlDy5Ytee655+jYsSNXXXUVSUlJLFu2jFGjRtGoUSM2bdpEeno6HTt2PKbPKlM2fyWNtm3B7XWyNvto9U7+9/11nNwy2boIMbHtt7+FcjoGrLABA+CRR8pcZPXq1dx777188803tGjRgqysLCZNmnT479lnn+Xmm2/m7bffBiAjI4NvvvmGQCDAXXfddUS36KZifNE8dVgtv3Iq51CQa6cvZsoLS+nZpiF3
ethDqDGx7LPPPmP8+PG0aNECgGbNmjF//vzDw7JOnDiRr7766vDy48ePJxAoqdFHdotuKsZfNY1abHv2If42Zy2frsvkujO6MvWc7tYsZWJfOTUCv4jsBv14z030/FHT8LjTxOqWlpnLmIfmMnvlDiYM68gfftzLEoYxJ+DMM8/k9ddfZ+/evQBkZWVx6qmnMnPmTABeeuklTq8jF9XUNH/VNGph81ReQRE3vLSMxEAcj14+kDO6t/A6JGNiXp8+ffjjH//I6NGjCQQCDBw4kMcee4yrr76aBx988PCJcFP1fNE1euuk1NE789NgwQIYPtzTeKqKqvLAB+t5d8V2duQcYsY1wzkt1RKGiX2x1DV6dbOu0b1Wi2oaT32Zzn++2ET7pvV5cuIQSxjGmFqh3KQhQgcRPhdhjQirRbjFLW8mwscibHQfm7rlIsKjIqSJsFKEQdX9Jvzm5YU/8Lc56zizZytmThnBOb1bex2SMaauEhmLyHpE0hC5o4zlfoqIIjKkrNVFU9MoAn6nSm9gBHCDCL2BO4BPVUkFPnWfA5wPpLp/U4Anyn1PUQQRKzL2HeTud1fTrVUKf7mot42JYYzxjkgAeBznuNwbmIBI7+Ms1xC4BVhY3irLTRqq7FDlW3c6F1gLtAPGAcUdRU0HLnGnxwEzVFFVFgBNRGhb3nbcwKNazK827srlqucWIwLPXz2UTs3tsj5TO3l9LtQPYmQfDAPSUE1HtRCYiXOMPtpfgQeA/PJWWKFzGiJ0BgbiZKPWquxwZ+0Eittg2gFbI16W4ZYdva4pIiwBBheFKhKFPwVDYcY/OZ+0zDz+9pNTaN+0gdchGVMtkpKS2Lt3b6wcNKuFqrJ3716Skrzt0SEBEhBZEvE35ahFyj8eiwwCOqD6XjTbjPqSWxFSgDeB36qyP7JSoIqKUKFPkCrTgGkizI0PMJpgRV7tP99u2Uf2wSD3X3oKlw6yLs5N7dW+fXsyMjLYvXu316F4Kikpifbtvf2uByGIapnnIMokEgf8E7gq2pdElTRESMBJGC+p8l+3eJcIbVXZ4TY/Zbrl24AOES9v75ZFtaFY9dm6TOLjhAv6RdcSZ0ysSkhIoEuXLl6HYaJT3vG4IdAXmOsef9sAsxC5GNUlx1thNFdPCfAMsFaVf0bMmgVMcqcnAe9ElF/pXkU1AsiJaMaqdXIOBfnz26t4+qvv+VHPVjRMSvA6JGOMKbYYSEWkCyKJwOU4x2iHag6qLVDtjGpnYAFQasKA6Goao4CJwHciFHdn+T/A/cBrIkwGtgA/c+fNAX4MpAEHgaujf3+xJTM3n58/uYDv9xzgp4Pa8+cL7YYnY4yPqBYhciPwIRAAnkV1NSL3AEtQnVX2Co5VbtJQ5StKvyr2rOMsr8ANFQ0EiKnmqa1ZB5n03CJ25uTzyq9GMPLk5l6HZIwxx1Kdg/NjPrLszlKWHVPe6nxyR3hsXYWhqvzu9RVk7i/gsQkDLWEYY+oMnyQNV4zUNF5csIVF32dx+9genG13extj6hB/JY0YkHWgkLvfXUPHZg2YMKyj1+EYY0yNsqRRQZ+u3UVRWHlswkASArb7jDF1i7+Oej5vnlJV3liawUmNk+jXvrHX4RhjTI3zV9LwuZcW/sDC77O4YnhH64jQGFMnWdKI0vd7DvCnt1fRv0MTrh/TzetwjDHGE/5KGj7+9X7ra8tplBTPfZf0JS7Ov3EaY0x18lfS8KmcQ0GWb81m8mld6dvOzmUYY+oufyUNn9Y0vtq4B1UY3Kmp16EYY4ynou4avS7KD4Z4ccEWHv10I6mtUhjWpZnXIRljjKd8kTT8WL/4cPVO/vjWd+zJK+T01BbcM64vifH+qpgZY0xN80XSOMwnzVMHCor409urKCwK8+TEwZzXp43XIRljjC/4K2n4wAsLtvDSgi3szStgxjXDOS21hdchGWOMb/gkaXjfy21RKMzjn2/i4U820LpRPR6bMMgShjHGHMUnScPlUfOUqnLz
zGXM+W4np6e24LmrhhJv/UoZY8wx/JU0PKCq3PraCuZ8t5NfjujI3Rf3JWA37xljzHHV+Z/T//12G28t28bZvVrzl4v6WMIwxpgy+KumUcPNU/nBEH99bw1DOzflyYmDLWEYY0w56nRN47ttOWQfDPKr07tawjDGmCj4K2nUcE1j7vpMAAZZ9yDGGBMVfyWNGqSqvP/dTjo1b0CLlHpeh2OMMTGhTiaNUFh58MP1pO85wC+G2zjfxhgTrTp3IjyvoIhJzy5i6ZZ9/GRgOyad2rnat2mMMbWFv5JGNftkzS7+/M4qduTkM/Xs7txydqrXIRljTEypM0lj34FCprywhJYN6/Hc1UP5UY9WXodkjDExx1/nNKqpeWr51mwu+ffXhBX+cH4vSxjGGFNJvqhpSDV2WJi+O4/Lp82nSf1Enpk0hDN7WsIwxpjK8kXSOKwaahpvL99OfjDMizcNp1urlCpfvzHG1CX+ap6qBl9s2E3/9o0tYRhjTBWo1UnjvZU7WLE1mwv7neR1KMYYUyuUmzREeFaETBFWRZQ1E+FjETa6j03dchHhURHSRFgpwqAKRVOFzVN5BUX8vzdXMqhjEyaO7FRl6zXGmJgiMhaR9YikIXLHcebfisgaRFYi8ikiZR4wo6lpPA+MParsDuBTVVKBT93nAOcDqe7fFOCJKNZf5UJh5b731pBXUMSfL+xNUkLAizCMMcZbIgHgcZxjc29gAiK9j1pqGTAE1X7AG8Dfy1pluUlDlXlA1lHF44Dp7vR04JKI8hmqqCoLgCYitC1vG1WpKBRm3ONf8cqirVw6sB0DO1pnhMaYOmsYkIZqOqqFwEyc43QJ1c9RPeg+WwC0L2uFlT2n0VqVHe70TqC1O90O2BqxXIZbdgwRpoiwBBgcKiopPFFvfpvBqm37ueWsVP7xs/4nvD5jjPGrBEhAZEnE35SjFon6mOyaDLxf1jZP+JJbVVQqcaOFKtOAaSLMDcQzmqJyX1Ku1dtz+J+3VtGuSX1u+FE3xKMxx40xpiYEIYjqkCpZmcgvgSHA6LIWq2xNY1dxs5P7mOmWbwM6RCzX3i2Lzgke5D9es4uwKq9eN4LE+Fp9YZgxxkQjumOyyNnAH4GLUS0oa4WVPbLOAia505OAdyLKr3SvohoB5EQ0Y1W7pVv20aN1Q9o3bVBTmzTGGD9bDKQi0gWRROBynON0CZGBwJM4CSPz2FUcKZpLbl8B5gM9RMgQYTJwP3COCBuBs93nAHOAdCANeAq4Pso3dsK+TtvDlxv3MKJr85rapDHG+JtqEXAj8CGwFngN1dWI3IPIxe5SDwIpwOuILEdkVilrA6I4p6HKhFJmnXWcZRW4obx1lqqSzVOLN2dx5bOL6NYqhanndK/05o0xptZRnYPzgz6y7M6I6bMrsjqfNPxXvsPCjH0HuXb6ElLqxfOP8f1pXD+hCuMyxhgTyV8dFlZQYVGYG15eRjisvHvTaXRukex1SMYYU6v5K2lUsHnq+W++Z8XWbP51xUBLGMYYUwP80TxVidapYCjMY5+mMbBjEy44pUZvOjfGmDrLH0mjEpZvzSa3oIgpp3e1m/iMMaaG+CtpVODg/9rircQJnHpyi2oMyBhjTCRfJI2K1hPmbdjN60sz+PnQDjRuYFdLGWNMTfFF0jgsiprG/vwgN89cRmqrFP584dE9/BpjjKlO/rp6Kgr/mbuJ7INBXpw8nAaJMRe+McbENH/VNMqxaXce/567iUsHtqNvu8Zeh2OMMXWOv5JGOc1Tf3tvLYnxcdw2tkcNBWSMMSaST5JG+Tdq5OYHmbthN9eM6kLbxvVrICZjjDFH80nSKN/SLfsIhZXTutkltsYY4xV/JY1SmqcOFBTx9JffkxAQBnZsUsNBGWOMKeavpFGKv8xazdeb9nDnhb1JrmdXTBljjFf8lTSOU9NYumUfbyzNYPKoLkwc2bnmYzLGGHOYv5LGUQ4UFDFlxhIArv9RN4+jMcYY4+uk8cnaXew9
UMgjPx9As+REr8Mxxpg6z19J46jmqU/WZtIipR4X9z/Jo4CMMcZE8lfSiKCqLEzfy6knNycuzro+N8YYP/Bt0li2NZvM3AKGdmnmdSjGGGNc/koabvNUOKzc/sZKmiUnclE/G5XPGGP8wl9Jw/X452mkZeZx54W9adLAToAbY4xf+C5prNm+n39+soGL+p/EuAF2AtwYY/zEX0lDhAc/XEejpATuHdfXxv42xhif8UXSELeX27z8IPM27uEXwzvaMK7GGONDvkgaxb5Oz3J6sk21nmyNMcaPfJU0Xl28lbaNkxjcqanXoRhjjDkOXyWNVdtyuOnMVOrFB7wOxRhjzHH4ImmE1TmncVav1kwY1sHjaIwxxpTGH0kj7Dz+5eI+dsWUMcb4WLUkDRHGirBehDQR7ihv+eKaRlKCNUsZY0yVEhmLyHpE0hA59ngsUg+RV935CxHpXNbqqjxpiBAAHgfOB3oDE0ToXdZrNOLFxhhjqojIMcdjRI4+Hk8G9qHaDXgYeKCsVVZHTWMYkKZKuiqFwExgXFkvaBvaWQ1hGGNMnTcMSEM1HdXSjsfjgOnu9BvAWWWdJ6iOpNEO2BrxPMMtO4IIU0RYAgw+EJcCt9wCrVpVQzjGGFM7JUACIksi/qYctUg0x+OSZVSLgBygeWnbjD/hqCtJlWnANBHm7mpw8mgeecSrUIwxJiYFIYjqkJrcZnXUNLYBkdfNtnfLjDHG1Kxojscly4jEA42BvaWtsDqSxmIgVYQuIiQClwOzqmE7xhhjyrYYSEWkCyKlHY9nAZPc6cuAz1BVSlHlzVOqFIlwI/AhEACeVWV1VW/HGGNMOVSLEDnieIzqakTuAZagOgt4BngBkTQgCyexlKpazmmoMgeYUx3rNsYYUwGqxx6PVe+MmM4Hxke7Ol/cEW6MMSY2WNIwxhgTNUsaxhhjomZJwxhjTNQ8u7kvwvK8vOVniMghrwPxiXigyOsgfMD2QwnbFyVsX7gSICkMDWp6u1LG5bg1F4TIEq3huxr9yvaFw/ZDCdsXJWxflPBqX1jzlDHGmKhZ0jDGGBM1vySNaV4H4CO2Lxy2H0rYvihh+6KEJ/vCF+c0jDHGxAa/1DSMMcbEAEsaxhhjouZ50hCRsSKyXkTS5HiDnsc4EXlWRDJFZFVEWTMR+VhENrqPTd1yEZFH3X2xUkQGRbxmkrv8RhGZdLxt+Z2IdBCRz0VkjYisFpFb3PI6tT9EJElEFonICnc/3O2WdxGRhe77fVWcrqwRkXru8zR3fueIdf3BLV8vIud5845OnIgERGSZiMx2n9fJfSEim0XkOxFZLiJL3DJ/fT9U1bM/nK56NwFdgURgBdDby5iq4T2eAQwCVkWU/R24w52+A3jAnf4x8D4gwAhgoVveDEh3H5u60029fm+V2BdtgUHudENgA85g93Vqf7jvJ8WdTgAWuu/vNeByt/w/wG/c6euB/7jTlwOvutO93e9MPaCL+10KeP3+KrlPbgVeBma7z+vkvgA2Ay2OKvPV98PrmsYwIE1V07X0Qc9jmqrOw+mjPlLkQO7TgUsiymeoYwHQRETaAucBH6tqlqruAz4GxlZ/9FVLVXeo6rfudC6wFmd84jq1P9z3k+c+TXD/FDgTeMMtP3o/FO+fN4CzRETc8pmqWqCq3wNpON+pmCIi7YELgKfd50Id3Rel8NX3w+ukEc2g57VRa1Xd4U7vBFq706Xtj1q3n9xmhYE4v7Lr3P5wm2OWA5k4X+pNQLaqFneREfmeDr9fd34O0JxasB9cjwC3A2H3eXPq7r5Q4CMRWSoiU9wyX30//ND3VJ2mqioideq6ZxFJAd4Efquq+50fio66sj9UNQQMEJEmwFtAT49D8oSIXAhkqupSERnjdTw+cJqqbhORVsDHIrIucqYfvh9e1zSiGfS8NtrlViNxHzPd8tL2R63ZTyKSgJMwXlLV/7rFdXZ/qGo28DkwEqd5ofiHXOR7Ovx+3fmNgb3Ujv0wCrhY
RDbjNE+fCfwfdXNfoKrb3MdMnB8Tw/DZ98PrpLEYSHWvlCht0PPaKHIg90nAOxHlV7pXRYwActxq6YfAuSLS1L1y4ly3LKa4bc/PAGtV9Z8Rs+rU/hCRlm4NAxGpD5yDc37nc+Ayd7Gj90Px/rkM+EydM56zgMvdK4q6AKnAopp5F1VDVf+gqu1VtTPO9/8zVf0FdXBfiEiyiDQsnsb5XK/Cb98PH1wt8GOcq2g2AX/0Op5qeH+vADuAIE7b4mScNthPgY3AJ0Azd1kBHnf3xXfAkIj1XINzci8NuNrr91XJfXEaTpvtSmC5+/fjurY/gH7AMnc/rALudMu74hzo0oDXgXpueZL7PM2d3zViXX9098964Hyv39sJ7pcxlFw9Vef2hfueV7h/q4uPh377flg3IsYYY6LmdfOUMcaYGGJJwxhjTNQsaRhjjImaJQ1jjDFRs6RhjDEmapY0jCmFiDQRkeu9jsMYP7GkYUzpmuD0qmqMcVnSMKZ09wMnu2MbPOh1MMb4gd3cZ0wp3J54Z6tqX49DMcY3rKZhjDEmapY0jDHGRM2ShjGly8UZltYY47KkYUwpVHUv8LWIrLIT4cY47ES4McaYqFlNwxhjTNQsaRhjjImaJQ1jjDFRs6RhjDEmapY0jDHGRM2ShjHGmKhZ0jDGGBO1/w+j5lqTudx8dgAAAABJRU5ErkJggg==\n",
232 | "text/plain": [
233 | ""
234 | ]
235 | },
236 | "metadata": {
237 | "needs_background": "light"
238 | },
239 | "output_type": "display_data"
240 | }
241 | ],
242 | "source": [
243 | "random.seed(42)  # use int(time.time()) instead for a non-reproducible run\n",
244 | "bandit = Bandit(20)\n",
245 | "agent = UCB1(bandit, alpha=2.0)\n",
246 | "agent.add_listener('regret', update_regret)\n",
247 | "agent.add_listener('corr', update_rank_corr)\n",
248 | "agent.run(5000)\n",
249 | "plot_stats('alpha=2.0')"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": 328,
255 | "metadata": {},
256 | "outputs": [
257 | {
258 | "data": {
259 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEWCAYAAABxMXBSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8VdW5//HPkxAIQyDMIqAgIIOICDhSi3XE4WonetXWkV5sr61We7Xa0Q7en1ZbbW21xWrF1uLUVrnWOk9VKwqKyExAhjATCCSEkOn5/bF3yCFkTk72Psn3/Xqd19l7n3XWfs5+Jec5a+291zJ3R0REpKHSog5ARERSixKHiIg0ihKHiIg0ihKHiIg0ihKHiIg0ihKHiIg0ihKHpDwzu8LM3mrpsiJSMyUOkSQws9PNbJmZFZnZa2Z2eB1l15jZXjMrDB8vtmasIo2lxCHSwsysD/A34AdAL2Ae8Hg9b/sPd+8WPs5KdowizaHEISnDzG42s1VmVmBmS8zsc7WUczO71sxWm9l2M7vTzNKqlbnLzHaa2Sdmdk7C9ivNbGm4j9VmdnUTQv08sNjdn3T3YuBW4BgzG9WEukRiR4lDUskq4BSgB/Bj4M9mNqCWsp8DJgETgAuBqxJeOwFYDvQBfg48aGYWvrYVOB/oDlwJ3G1mEwDM7DAzy6/jcUlYx1HAR5U7c/c9YexH1fHZHjWzbWb2opkd07DDIRINJQ5JGeEv+I3uXuHujwMrgeNrKX6Hu+9w93XAPcDFCa+tdfcH3L0cmAUMAPqH+/iHu6/ywBvAiwTJCndf5+7ZdTz+EtbfDdhVLZ5dQFYtsX4ZGAIcDrwGvGBm2Q0+MCKtTIlDUoaZXWZmCyp/4QNjCVoNNVmfsLwWODRhfXPlgrsXhYvdwn2cY2bvmtmOcB/n1rGP2hQStFgSdQcKairs7m+7+153L3L3/wfkEyYrkThS4pCUEF6V9ADwDaC3u2cDiwCr5S2DE5YPAzY2YB+dgL8CdwH9w308V7mPsKuqsI7Hl8OqFgPHJNTbFRgWbm8Ir+NziUROiUNSRVeCL9RtEJzEJmhx1OZGM+tpZoOB66j/qiaAjkCncB9l4Unz/Vc4hV1V3ep4PBoW/Tsw1sy+YGaZwA+Bhe6+rPoOw2Q02cw6mlmmmd1I0MJ5uwHxikRCiUNSgrsvAX4B/BvYAhxN3V+uzwDzgQXAP4AHG7CPAuBa4AlgJ3AJMKcJsW4DvgDcFtZzAnBR5etm9jsz+124mgXcH5bbAEwFznH3vMbuV6S1mCZykrbGzBwY4e45Ucci0hapxSEiIo2ixCEiIo2irioREWmUyFscZtxjxj1RxyEiklLM7sEsku/OyFscZrxu1nFK584ZkcYhIpJSiospqajwUvdWbwB0aO0d1qRr13EUFMyLOgwRkdRx6qmMfuONPVHsOvKuKhERSS1KHCIi0ihKHCIi0ihKHCIi0ihKHCIi0ihKHCIi0ihKHCIi0ihKHCIiqeaeeyj+6OPIdq/EIanBLHhMngzFxVFHIxKdigq4/noy83dEFkIs7hyXGNu3Dz79aZg2Df7nf6KJ4amnqpbfeQc6d4ZBg6KJRaoUF8P27dCjB2RlNew97rBhA/TtC506JTe+tio/H4CijE5Qui+SEJQ4pG4LF8J77wWP1aujieFf/wqes7PhK1+BoqJo4pAD7dgBH38Mxx8fJPOG2L4dCgpg1CgYMSK58bVhH2wpwhcsgA0HzUbcKpQ4pG65uVXLib/8W9tXvwoPPBDd/kViwt254a7XuX/ltyOLQYkjDj74AI49NujDj4vCQnjlFbj00mB9zRo4/PBIQxJpiN3FpWzM34t70DMG4AQLiYOBux+43fdv94Tl4N0HlwnKVb16YH1UL1vLfj5Yu5MtuxvX3VRcVs6avCJ6dY2uq0+JI2pPPglf+hL8+c/Bcxxs3Qo33gizZwfrw4fD
oYdGG5O0GfvKytleWJK0+r86ax5LN+1OWv3JcEj3zEaVH3VIFr26dkxSNPVT4ojakiXB81e+EjzipG9feOONoD86Tq0hSWmXP/Qe765O7hVB540bwPlHDwAq/3Rt/3LlX7KZJSxX/YkbVYUsLFe1XFXGEspQy/b97z1gvwfGM+qQLLp0bMJX8dPR/U8qcUQt8UTvddcFX9ZRW78+6Jq6/noYPTrqaKQNKS4tZ96anZx9VH9OH9U/KfswgzNG96dnhL/I2zoljqj9/OdVy9ddB0OHRheLSJJ9vGEXZRXOtImDOWNMchKHJJ8SR5z06xd1BCKNsiF/L9sK9lFeUUF5BZRVVFBe4ZRVODsKS/hk+56qE8bA4o3BuYfxh2VHFbK0ACWOuNi+Hbp2jToKkQbbtbeUz9z5OiXlFXWW65B2YF/8iUf0ok833fyXypQ4onbkkbBiBfTuHXUkIo2yfkcRJeUVXHvacCYN6UWHNCMtzeiQZqSHjyP7Z5GZkR51qNLC6k0cZjwEnA9sdWdstde+DdwF9HVne3iRwq+Ac4Ei4Ap3Pmj5sNuQjAz4/OejjkJi5IXFm/nZP5ZQUfcP+cjtKysH4LTR/Rk/WF1P7UlDWhwPA78BHkncaMZg4CxgXcLmc4AR4eME4P7wWWpTUqIxewQIrjjatbeUm/+6EDPjMyPjf84ru0sGRx3aPeowpJXVmzjcedOMITW8dDdwE/BMwrYLgUfCGy/fNSPbjAHubGqRaNuiffuUONq5lVsKmD5rHut2VF2afePZI7nmM8MjjEqkdk06x2HGhcAGdz6qdl/YQGB9wnpuuO2gxGHGDGAGMLK0tClRtBElJdBR15u3F+6OO1S4U+GwfmcRv3kth/U7i/j6qcMY0COTLh078PljB0YdqkitGp04zOgCfJegm6rJ3JkJzDTj9YwMpjSnrpSRmwtvvw3/+Z9V2/bsgS5dootJmmT55gIue2gu+8oqqKgImtiV4xdVhIkh2BasVz7XZtLhPfnO1FGtFb5IszSlxTEMGAr7WxuDgA/MOB7YAAxOKDso3JZaysth5crgubrt2+H11+H994NhvhOHhu7bF6bUkQOnT4e5c+Gww6B79+CbprCw4XMZSKt68K1PWLRhV42vLdqwiy2793HG6H4M6tklHFLCSDNIS7P9w02khUNZpJntH+IiLWF7Rnoak4b0ZPQAnSeQ1NHoxOHOx8D+s3ZmrAEmhVdVzQG+YcZjBCfFd0V6fiMnJ/iyfuYZuOYamD+/7vLLlwfP6ek1J42WcvLJB6736ZO8fbUDH6zbyfrw/EBlN1DlKKQV4YJT2UVUtVzZIqhsKQTlq7bd/vwyumd2ICszo8b9Xj3lCG45R0OySPvTkMtxZwOnAn3MyAV+5M6DtRR/juBS3ByCy3GvbKE4m+a22+DNN6Fnz6ptid1EidxhyxYYMADGjQtuxps6teayQ4YEk9cA+6+ZLCyEl16qO+GUlQWznyUOT96hA5zVrF6/dmXJxt1sLSimtNwpKatg/c4ibv9nciazSTOYedkkjhvSKyn1i6SqhlxVdXE9rw9JWHbgmuaH1UImTICHHw6mtty1Kxi2/LHHWqbu2bNh8OCq4TKzsnQ/RpL96d21/ODpRQdtHzuwO3d/afz+LqKgW6hqpNLgUf21cFu4nPgeLEgaGelpunlNpAZt+87xzHCM+8WLYd06OOaYlqv7ootari6p19q8PfzomUUcN6Qnt5w7mo7paWSkp5GRbgzq2YWOHdKiDlGk3WjbiaMknCymY0c46aRoY5FmeXbhJiocfjFtPIf11lVoIlFq2z/TKhNHRs0nNyV1/N9HG5l4eE8lDZEYaNuJo/LOQt1gl9JWbilg2eYC/mPcgKhDERHaeuJI7KqSlPXXDzZgBucercQhEgftI3Gk68qYVPXa8q387o1VnDi0N/26Z0YdjkjqMZuK2XLMcjC7uYbXD8PsNcw+xGwhZufWV2XbPjm+cOGBs9BLbCxYn8+a
7XsoLi0nb08Ju/YeOGDZ3pJyXl66hU27ikkz+Olnj4ooUpEUZpYO/BY4k2DswPcxm4P7koRS3weewP1+zMYQ3I83pK5q23biMKu6JFdio6C4lGm/e4fS8qrBmzLSjQ5pBzaA0wy+OHEQV5w8hOH9NCyLSBMcD+TgvhoAs8cIRjFPTBwOVI550wPYWF+lbTtx5OXBCZoOJG4WrM+ntNy5+z+P4cQjetMlowM9uujKN5HGyoAMzOYlbJqJ+8yE9ZpGLK/+pXgr8CJm3wS6AmfUt9+2nTj27oV+8Z8Mp735YG0+ZnDG6P61jgMlIvUrhVLcJzWzmouBh3H/BWYnAX/CbCzutc5BmTqJIy8P3ngjGO/po4/gC1+AoqJgOJGRI2u+cqqoCDp3bv1YpU7z1+1kZP8sJQ2R5GvIiOXTgWBgPvd/Y5YJ9AG21lZp6iSO6iPI/u//Vi1/85vw618Hy2vWwPPPw9e+FiQOzXURC398+xPW5gUj2M5fs4MLNVGRSGt4HxiB2VCChHERcEm1MuuA04GHMRsNZALb6qo0dRJHpQEDgrksTjklmARp9my491649trg9SlTYONG6N07GO1WXVWRe2bBBn78f8G5uO6ZHeiUkc7ZRx0ScVQi7YB7GWbfAF4A0oGHcF+M2U+AebjPAb4NPIDZ9QQnyq/AvY5px1IpcRx9dNBNtWTJgdtnzw6eEydUgmAkXDhwCHNpdYX7yvjW4wsAmPvd0+mvezFEWpf7cwSX2CZu+2HC8hJgcmOqTJ3EYQZHHnnw9sJCePbZquFF9u0LJm4666xgRr7zzmvdOOUA7+Rsxx2+f95oJQ2RNiJ1Eod7zTfyde168ORM06e3TkxSp+LScv7nyY/omJ7GpSep5SfSVqTOkCPukJY64Qrc/fIKdheX8e2zjqRTBw37ItJW1PtNbMZDZmw1Y1HCtjvNWGbGQjP+bkZ2wmu3mJFjxnIzzm6xSCsqNHRIiti5p4TP3/c2M99czUXHDebqKcOiDklEWlBDfsI/TOU1vlVeAsa6Mw5YAdwCYMYYgsu9jgrfc58ZLfNTs7auKomd255bysLcXXzzM8P5/vljog5HRFpYQ+Ycf9PswAGv3HkxYfVd4Ivh8oXAY+7sAz4xI4dgrJR/NztSJY5IFZeWs25HEeUVTnmFU+GVz1QtVzhr8op4an4u13xmGDecNTLqsEUkCVri5PhVwOPh8kCCRFIpN9x2EDNmADOAkaWlNZWoRokjMu/kbOeGJz5i8+7iBpUf1rcr3zxtRP0FRSQlNStxmPE9oAx4tLHvdWcmMNOM1zMymNKQNyhxtD535zt/W0jenn1899xRDO7ZhbQ0I92M9DTDDNLD9bS0YNvoAd3JzNDJcJG2qsmJw4wrgPOB092pvMuwIeOiNI0SR6sqLi3no/X5/P7N1azfsZdfTDuGL0wcFHVYIhIDTUocZkwFbgKmuFOU8NIc4C9m/BI4FBgBvNfsKEGJo5Vd8+gHvLIsGOMsu0sGZ4/VECEiEqg3cZgxGzgV6GNGLvAjgquoOgEvhd/l77rzNXcWm/EEwSQhZcA17pS3SKRKHK1mb0k5ryzbyqhDsrjtc2M5dnBP0tJ07EUk0JCrqi6uYfODdZS/DbitOUHVVrFuAGwdO4qCudqvOHkIEw/vFXE0IhI3qfNNrBsAW01+mDiyu9Qwx4mItHupkzjUVdVq8ouC66OzNZ2riNRAiUMOsjNscfRUi0NEaqDEIQepbHH0VItDRGqgxCEHqTzH0UOJQ0RqkPrzcUiLKSmrYNW2QlZv20PXjukaCl1EaqTEIcxfu5NnF27kxcVb2JC/F4Djh+gyXBGpmRJHO/Xa8q0s2bgbgL9+kMu6vCIG9uzM/5x1JMP7dePoQdn11CAi7ZUSRwyszdvD5+57hz37ylplf07QLZXov04ZyvfO09wZIlK/1EocbfTO8Y837GLHnhIuPn4w3Tu3zgnpNDOmTRzEwJ6dAXQ+Q0QaLHUSRxu+c3zD
zuC8wnfPHU1Wpq5kEpF4S52f8G24q2pj/l6yMjsoaYhISlDiiIEN+XsZmN056jBERBpEiSMGcncqcYhI6lDiiJi7By2OnkocIpIa4p049uwJksWjj7bZxPHxhl0UFJdx9MAeUYciItIg8U4cmzcHzz/4QZtNHC8u3kKHNOPMMf2jDkVEpEHqTRxmPGTGVjMWJWzrZcZLZqwMn3uG282MX5uRY8ZCMyY0K7pOnYLnTz6BXbvaZOJYvHEXw/t106RJIpIyGtLieBiYWm3bzcAr7owAXgnXAc4BRoSPGcD9zYqueqJog4lj1bY9DO/XLeowREQarN7E4c6bwI5qmy8EZoXLs4DPJmx/xB13510g24wBLRUspaUtVlUcFJeWs35nEcP6KnGISOpo6jmO/u5sCpc3A5Ud9AOB9QnlcsNtBzFjhhnzgIm15gP3A9crKmoul4I25u9l+qz3cYdhanGISApp9slxd5xg3LzGvm+mO5OA+Rm13TBdPXGMSf1B+AqKS/nFi8uZes+bvJ2Tx/FDenHysN5RhyUibZXZVMyWY5aD2c21lPkSZkswW4zZX+qrsqljVW0xY4A7m8KuqK3h9g3A4IRyg8JtTVO9hXHllU2uKg7cnVv+9jH/+HgT2Z0zmPHpI/juuaOjDktE2iqzdOC3wJkEPUDvYzYH9yUJZUYAtwCTcd+JWb/6qm1qi2MOcHm4fDnwTML2y8Krq04EdiV0aTVeYovjuOOga9cmVxUHz328mWcXbuLrU4bx4Q/PUtIQkWQ7HsjBfTXuJcBjBOeiE/0X8FvcdwLgvpV61NviMGM2cCrQx4xc4EfA7cATZkwH1gJfCos/B5wL5ABFQPOaCImJI4WvqMovKuHrf/6ARRt30T2zA98+a2TUIYlIG5ABGZjNS9g0E/eZCes1nXc+oVo1RwJg9jaQDtyK+/N17bfexOHOxbW8dHoNZR24pr46GyyxqypF5uIoK69gb2k5peVOWXkFJeUV/OzZpfx7dR6jDsnia1OGkZ6WuklQROKjFEpxn9TMajoQ3EJxKsHphTcxOxr3/LreEF8p1uL47Ws5/O6NVRQUHzyT301TR/Lfpw6PICoRaccact45F5iLeynwCWYrCBLJ+7VVGu/EkUItjm0F+7jzheVkd8ngO1NH0TkjjYwOaWSkp9G7a0c+M7Le800iIi3tfWAEZkMJEsZFwCXVyjwNXAz8EbM+BF1Xq+uqNN6JI4VaHB+sC84rPXj5JCYe3iviaEREAPcyzL4BvEBw/uIh3Bdj9hNgHu5zwtfOwmwJUA7ciHteXdXGO3EktjhW15kAI1FWXsGKLYW8uGQz97y8kvQ046hDNcqtiMSI+3MEFy4lbvthwrIDN4SPBol34khscWzcGF0cNdiyu5iv/Xk+H66rOn9017RxZGakRxiViEjyxTtxxHiIkftey+HDdfkcMzib70wdydiBPeiuOcNFpB2Id+KoPuRIjCzI3cWkw3vy1NdPjjoUEZFWFe9LlRJbHH/7W3RxVFO4r4yFufkce1h21KGIiLS6eCeOxBbHEUdEF0c1F9z7Fu5w4hEanFBE2p/USRwxuRw3v6iE1dv3cN64AZw2SvdmiEj7E+/EEcMbAHO2FgLwxQmDsJgkMxGR1hSPb+PaJLY4YpI4VoaJQ9O9ikh7Fe+rqmbNqlru3LlVd51XuI+ikvKDtn+0Pp/MjDQGZrduPCIicRHvxPGb3wTPP/sZDB3aarvN2VrAmXe/WevVwOMG9SBNI9xKO1VaWkpubi7FxcVRhxKpzMxMBg0aREatU5i2XfFOHJmZUFwM3/1uq+52zoKNuMNPLzyKzh0PPkTjB2tYEWm/cnNzycrKYsiQIe32PJ+7k5eXR25uLkNb8UdtXMQ7cRx7bDDrXyv+cd46ZzEPv7OGw3t34dKThrTafkVSRXFxcbtOGgBmRu/evdm2bVvUoUQi3omjogLSW2/sp3dytvPwO2s4Y3R/bj5Hs/SJ1KY9J41K7fkY
NOtSJTOuN2OxGYvMmG1GphlDzZhrRo4Zj5vRsck7qKhotaup9pWV8/2nF3FYry785pJjGd4vq1X2KyLxtWbNGv7yl79EHUbsNPlb2YyBwLXAJHfGEoz1fhFwB3C3O8OBncD0JkdXXt5qieP+11exevsefvrZsRrhViSFuDsVzRgQtazs4Bk7Kylx1Ky538odgM5mdAC6AJuA04CnwtdnAZ9tcu2t1FW1elsh9722ivPHDWDKkX2Tvj8RaZ41a9YwcuRILrvsMsaOHcuf/vQnTjrpJCZMmMC0adMoLAzut3ruuecYNWoUEydO5Nprr+X8888H4NZbb+XSSy9l8uTJXHrppZSXl3PjjTdy3HHHMW7cOH7/+98DcPPNN/Ovf/2L8ePHc/fdd0f2eeOmyec43Nlgxl3AOmAv8CIwH8h3pzKF5wIDa3q/GTOAGcDI0tJadtIKLQ535/tPL6JThzR+eP6YpO5LpK358f8tZsnG3S1a55hDu/Oj/ziq3nIrV65k1qxZDB8+nM9//vO8/PLLdO3alTvuuINf/vKX3HTTTVx99dW8+eabDB06lIsvvviA9y9ZsoS33nqLzp07M3PmTHr06MH777/Pvn37mDx5MmeddRa33347d911F88++2yLfsZU15yuqp7AhcBQ4FCgKzC1oe93Z6Y7k4D5tV4G3QotjqcXbOCdVXncNHUk/bpnJnVfItJyDj/8cE488UTeffddlixZwuTJkxk/fjyzZs1i7dq1LFu2jCOOOGL/5bLVE8cFF1xA5/DG4hdffJFHHnmE8ePHc8IJJ5CXl8fKlStb/TOliuZcVXUG8Ik72wDM+BswGcg2o0PY6hhEMEF60yS5xZFfVMLPnl3K+MHZXHLC4Unbj0hb1ZCWQbJ07doVCHoNzjzzTGbPnn3A6wsWLGjQ+yvruPfeezn77LMPKPP666+3TLBtTHO+ldcBJ5rRxQwDTgeWAK8BXwzLXA480+Q9JLnFccfzy8jfW8r/fu5o0nUnuEhKOvHEE3n77bfJyckBYM+ePaxYsYKRI0eyevVq1qxZA8Djjz9eax1nn302999/P6Vhv/mKFSvYs2cPWVlZFBQUJP0zpJomJw535hKcBP8A+DisaybwHeAGM3KA3sCDTY4uiS2OeWt2MPu99Vw1eQhjDu2elH2ISPL17duXhx9+mIsvvphx48Zx0kknsWzZMjp37sx9993H1KlTmThxIllZWfToUfOoD1/96lcZM2YMEyZMYOzYsVx99dWUlZUxbtw40tPTOeaYY3RyPIF5xNOzmvF6t26TphQUzIMNG+Dll4NxqU45BYYNg099Ch55pEX3WVpewfm/fouC4lJeumEKXTvF+z5IkThZunQpo0ePjjqMBiksLKRbt264O9dccw0jRozg+uuvb7H6Iz0Wp57K6DfeKFzq3uo3ncXnG7OiAsaOhfz8qm3p6S3W4nB38vaUUF7h3PPySpZvKeCByyYpaYi0YQ888ACzZs2ipKSEY489lquvvjrqkNqEWHxrplEBEyZUJY1DDoEzzwzGqPrqV1tkH3+eu44fPL1o//q5Rx/CmWP6t0jdIhJP119/fYu2MCQQi8TRoaIEPloMvXrB6tVQSz9kc7y2bCuDenbm66cOo1unDlxwzKEtvg8RkfYgFoljv9/8JilJA2DDzr2MHtCdL+uyWxGRZonHfKxJ5u5s3l1Mv6xOUYciIpLy4pU4kjRM8cZdxezaW8rIQzTirYhIc8UrcSTJR+uDk+7HDMqOOBIRkdTXbhJHx/Q0Rg1Qi0OkPas+hHpdQ6pL7eJ1cjwJXVWPzl3Lsws3MfrQ7nTqoHk2RNqKRx55hLvuugszY9y4cfz0pz/lqquuYvv27fTt25c//vGPHHbYYVxxxRVkZmby4YcfMnnyZLp3786qVatYvXo1hx122EFjXEn94pU4Wpi78+M5S+iQbkz/VPubUF4k6b71LahnMMFGGz8e7rmnziKLFy/mZz/7Ge+88w59+vRh
x44dXH755fsfDz30ENdeey1PP/00ALm5ubzzzjukp6dz6623HjCkujReLLqqkjW84O69ZZSUV3DDmUdylRKHSJvx6quvMm3aNPr06QNAr169+Pe//80ll1wCwKWXXspbb721v/y0adNITxgwNXFIdWm8eLU4WriralvhPgD6dNNluCJJUU/LIC4Sh1Cvab1NM5sK/Ipgeu8/4H57LeW+QDBw7XG4z6uryli0OJJluxKHSJt02mmn8eSTT5KXlwfAjh07OPnkk3nssccAePTRRznllFOiDDEezNKB3wLnAGOAizE7eKpTsyzgOmBuQ6qNSYsjOSP0bisIEke/7kocIm3JUUcdxfe+9z2mTJlCeno6xx57LPfeey9XXnkld9555/6T48LxQA7uqwEwe4xg5tYl1cr9FLgDuLEhlcYkcYRauqsqTBx91eIQaXMqT4QnevXVVw8q9/DDDx+wfuuttyYxqtgZCKxPWM8FTjighNkEYDDu/8AsBRNHC9tWuI+MdKNH59omNRcRSV0ZkIFZ4vmImbjPbHAFZmnAL4ErGrPftp04CvbRp1sn0jQtrIi0QaVQivukOopsAAYnrA8Kt1XKAsYCr4c9PocAczC7oK4T5M06OW5GthlPmbHMjKVmnGRGLzNeMmNl+NyzMRW2pG0F++irgQ1FpP16HxiB2VDMOgIXAXP2v+q+C/c+uA/BfQjwLlBn0oDmX1X1K+B5d0YBxwBLgZuBV9wZAbwSrkeissUhIi0r6imn4yAljoF7GfAN4AWC7+cncF+M2U8wu6Cp1Ta5q8qMHsCnCfvG3CkBSsy4EDg1LDYLeB34TlP30xw7i0oYc2j3KHYt0mZlZmaSl5dH7969sSSNaB137k5eXh6ZmZlRh1I/9+eA56pt+2EtZU9tSJXNOccxFNgG/NGMY4D5BNcB93dnU1hmM9Dw+Vlb8I/Q3dmxp4TeXTu2WJ0iAoMGDSI3N5dt27ZFHUqkMjMzGTRoUNRhRKI5iaMDMAH4pjtzzfgV1bql3HGzmm/SMGMGMAMYmYwBKotKytlXVkEvJQ6RFpWRkcHQoRrCpz1rzjmOXCDXff+dhk8RJJItZgwACJ+31vRmd2a6MwmY36EyfbVQi+NXL6/k+NteBnTXuIg6keRfAAAOhElEQVRIS2ty4nBnM7DejJHhptMJ7kacA1TelXM58EyzImykNdv3cPfLKxjcqwvXnT6CqWMPac3di4i0ec29j+ObwKNmdARWA1cSJKMnzJgOrAW+1Mx9NMqsf68B4PvnjeFTI/q05q5FRNqFZiUOdxYANd18cnqTKmxmV5W78/yizRzZv5uShohIkrSZO8ffztnOn99dy6ZdxXz91KOiDkdEpM2KReKo5cKrRvnJ/y1h1bZC+nTryFljdF5DRCRZYpE49mtiV1VBcSkrthZw3ekj+NYZR7ZwUCIikqhNTOT0ce4u3GH84OyoQxERafNSPnGUVziPvrcOUOIQEWkN8UocTeiqemr+ev6xcBNH9O1KdhfdJS4ikmzxShxNsGB9PgB/mn5CPSVFRKQlpHziWLRhNycd0ZuB2Z2jDkVEpF2IV+JoRFdVaXkFv3xpBR9v2MWUkX2TGJSIiCSKV+JohL9/sIFfv7KS4f26ceXkIVGHIyLSbsTrPo4GemHxZm7660I6pBn/vO4UMtJTNv+JiKSceH3jNrCraumm3QD88crjlDRERFpZSn7rfrJ9DwOzO3PKCJ3bEBFpbSmZOJZs3M3oAVlRhyEi0i7FK3E0oKuqqKSMVdsKGXNoj1YISEREqotH4vCGj467dNNuKhyOHqjEISIShXgkjkoNaHG8vDSYwlyJQ0QkGvFKHA2wYnMBGelG/+6dog5FRKRdanbiMCPdjA/NeDZcH2rGXDNyzHg8nI+8xazYWsA5YwdgzZxmVkREmqYlWhzXAUsT1u8A7nZnOLATmN7gmupJBvlFJWzKL2ZQT41LJSISlWYlDjMGAecBfwjXDTgNeCosMgv4
bHP2keiO55fhwAXjD22pKkVEpJGa2+K4B7gJqAjXewP57pSF67nAwJreaMYMM+YBE8vLaipxoPlrdzD7vfVcNXkIow7p3sywRUSkqZqcOMw4H9jqzvymvN+dme5MAuZ3qBwxq46uqtv+sZQBPTI1p7iISMSaM8jhZOACM84FMoHuwK+AbDM6hK2OQcCG5gZZWl7BwtxdzPj0EXTtlJLjMoqItBlNbnG4c4s7g9wZAlwEvOrOl4HXgC+GxS4HnmlukGvziiircIb369bcqkREpJmScR/Hd4AbzMghOOfxYIPfWUtX1apthQAM66vEISIStRbp93HndeD1cHk1cHxL1Ftpf+JQi0NEJHIpcef4qq17OKR7Jt10fkNEJHLxShx1dFUN69e1lYMREWkDzKZithyzHMxuruH1GzBbgtlCzF7B7PD6qoxJ4qh9dFx3Z9XWQp3fEBFpLLN04LfAOcAY4GLMxlQr9SEwCfdxBDdv/7y+amOSOGq3rWAfBfvKlDhERBrveCAH99W4lwCPARceUML9NdyLwrV3CW6jqFO8EkcNXVWrtu0B4Ii+6qoSEUmUARmYzUt4zKhWZCCwPmG91tE8QtOBf9a339ifbd5euA+A/t0zI45ERCReSqEU90ktUpnZV4BJwJT6isYrcdTQ4sgvKgEgu0tGa0cjIpLqNgCDE9ZrHs3D7Azge8AU3PfVV2m8uqpqsGNPKQA9u7TotB4iIu3B+8AIzIZi1pFglI85B5QwOxb4PXAB7lsbUmnsE8fOohKyOnUgIz32oYqIxIt7GfAN4AWCeZOewH0xZj/B7IKw1J1AN+BJzBZgNqeW2vZLia6q7K7qphIRaRL354Dnqm37YcLyGY2tMvY/43cWlaqbSkQkRmKfOPKLSshW4hARiY14JY4auqqCFoe6qkRE4iJeiaMGO4tK1FUlIhIjsUgctU0YW1BcSkFxGb27KnGIiMRFLBLHftW6qtZsD4ZPOfKQrCiiERGRGsQkcdQ8Ou62wmIA+mV1as1gRESkDk1OHGYMNuM1M5aYsdiM68Ltvcx4yYyV4XPPpu7jk7DFMbBn56ZWISIiLaw5LY4y4NvujAFOBK4xYwxwM/CKOyOAV8L1hqnWVbUwN59DumfSL0sDHIqIxEWTE4c7m9z5IFwuILidfSDBWO+zwmKzgM82dR8f5+7i6EE9mvp2ERFJghY5x2HGEOBYYC7Q351N4Uubgf61vGeGGfOAiWVlB79euK+M1dv3MG6gEoeISJw0O3GY0Q34K/Atd3YnvuaOU8uZb3dmujMJmN+hcsSssKvK3Xn47U8AGN5PM/+JiMRJsxKHGRkESeNRd/4Wbt5ixoDw9QFAg4bpTfTnueu468UVAIxVi0NEJFaac1WVAQ8CS935ZcJLc4DLw+XLgWcaW/fzi4KerpdvmMLgXl2aGqKIiCRBc4ZVnwxcCnxsxoJw23eB24EnzJgOrAW+1OAazSguLef9NTuZ/qmh6qYSEYmhJicOd96i9tFCTm9qvf/1yDxKyir41PA+Ta1CRESSKFYTOa3bsZd/rSxkcK/OnDSsd9ThiIhIDWIy5Ejgn+G5jfsumUhmRnrE0YiISE1ilTheWbaVgdmdddOfiEiMxSpxOPBfpwyNOgwREalDrBJHmsGlJw2JOgwREalDPBKHBzeXnzy8D+lptV2oJSIicRCLxFERDkpy6sh+0QYiIiL1ikXiqBzOqnc3TdgkIhJ3sUgclaMgZnfR3OIiInEXj8QRZo6unWJ1P6KIiNQgJokjyBxmOjEuIhJ3sUgcIiKSOmKROCq7qqrPOS4iIvETj8RR8ySBIiISQ7FIHCIikjrilTjUVSUiEnvxShwiIhJ7SUscZkw1Y7kZOWbcnKz9iIhIHcymYrYcsxzMDv4uNuuE2ePh63MxG1JflUlJHGakA78FzgHGABebMaYhbxQRkRZidtB3MWbVv4unAztxHw7cDdxRX7XJanEcD+S4s9qdEuAx4MLaCh9Wtj5J
YYiItGvHAzm4r8a9tu/iC4FZ4fJTwOn13Y2drMQxEEjMBrnhtv3MmGHGPGBiUVpXmD4djj46SeGIiLQ9GZCB2byEx4xqRer9Lj6gjHsZsAvoXdd+Ixscyp2ZwEwzXt/SZdgU/vCHqEIREUlJpVCK+6TW3m+yWhwbgMEJ64PCbSIi0noa8l1cVcasA9ADyKur0mQljveBEWYMNaMjcBEwJ0n7EhGRmr0PjMBsKGa1fRfPAS4Pl78IvLp/5NlaJKWryp0yM74BvACkAw+5szgZ+xIRkVq4l2F2wHcx7osx+wkwD/c5wIPAnzDLAXYQJJc6Je0chzvPAc8lq34REWkA94O/i91/mLBcDExrTJW6c1xERBpFiUNERBpFiUNERBpFiUNERBolshsAEywoLFzwaTPbG3UgMdEBKIs6iJjQsQjoOFTRsQhlQGYFdIli31bP5bqtE4TZPI/g7sc40rGoomMR0HGoomNRJcpjoa4qERFpFCUOERFplLgkjplRBxAjOhZVdCwCOg5VdCyqRHYsYnGOQ0REUkdcWhwiIpIilDhERKRRIk8cZjbVzJabWY7VNJF6ijOzh8xsq5ktStjWy8xeMrOV4XPPcLuZ2a/DY7HQzCYkvOfysPxKM7u8pn3FnZkNNrPXzGyJmS02s+vC7e3ueJhZppm9Z2Yfhcfix+H2oWY2N/zMj1swFDZm1ilczwlfH5JQ1y3h9uVmdnY0n6h5zCzdzD40s2fD9fZ6HNaY2cdmtsDM5oXb4vf/4e6RPQiG+V0FHAF0BD4CxkQZUxI+46eBCcCihG0/B24Ol28G7giXzwX+CRhwIjA33N4LWB0+9wyXe0b92ZpwLAYAE8LlLGAFMKY9Ho/wM3ULlzOAueFnfAK4KNz+O+Dr4fJ/A78Lly8CHg+Xx4T/N52AoeH/U3rUn68Jx+MG4C/As+F6ez0Oa4A+1bbF7v8j6hbH8UCOu6/22idST2nu/ibBGPeJEieHnwV8NmH7Ix54F8g2swHA2cBL7r7D3XcCLwFTkx99y3L3Te7+QbhcACwlmO+43R2P8DMVhqsZ4cOB04Cnwu3Vj0XlMXoKON3MLNz+mLvvc/dPgByC/6uUYWaDgPOAP4TrRjs8DnWI3f9H1ImjIROpt0X93X1TuLwZ6B8u13Y82txxCrsYjiX4pd0uj0fYPbMA2Erwz70KyHf3yiE1Ej/X/s8cvr4L6E3bOBb3ADcBFeF6b9rncYDgx8OLZjbfzGaE22L3/xGHsaraNXd3M2tX10SbWTfgr8C33H138IMx0J6Oh7uXA+PNLBv4OzAq4pBanZmdD2x19/lmdmrU8cTAp9x9g5n1A14ys2WJL8bl/yPqFkdDJlJvi7aETUrC563h9tqOR5s5TmaWQZA0HnX3v4Wb2+3xAHD3fOA14CSC7obKH3SJn2v/Zw5f7wHkkfrHYjJwgZmtIeiqPg34Fe3vOADg7hvC560EPyaOJ4b/H1EnjveBEeEVFLVNpN4WJU4OfznwTML2y8KrJU4EdoVN1BeAs8ysZ3hFxVnhtpQS9kU/CCx1918mvNTujoeZ9Q1bGphZZ+BMgnM+rwFfDItVPxaVx+iLwKsenAmdA1wUXm00FBgBvNc6n6L53P0Wdx/k7kMI/v9fdfcv086OA4CZdTWzrMplgr/rRcTx/yMGVxGcS3B1zSrge1HHk4TPNxvYBJQS9DVOJ+iTfQVYCbwM9ArLGvDb8Fh8DExKqOcqghN+OcCVUX+uJh6LTxH04S4EFoSPc9vj8QDGAR+Gx2IR8MNw+xEEX3g5wJNAp3B7ZrieE75+REJd3wuP0XLgnKg/WzOOyalUXVXV7o5D+Jk/Ch+LK78P4/j/oSFHRESkUaLuqhIRkRSjxCEiIo2ixCEiIo2ixCEiIo2ixCEiIo2ixCFSCzPLNrP/jjoOkbhR4hCpXTbBaKwikkCJQ6R2twPDwrkR7ow6GJG40A2AIrUIR/B91t3HRhyKSKyoxSEiIo2ixCEiIo2i
xCFSuwKCKW5FJIESh0gt3D0PeNvMFunkuEgVnRwXEZFGUYtDREQaRYlDREQaRYlDREQaRYlDREQaRYlDREQaRYlDREQaRYlDREQa5f8DBI4EUaP7KjcAAAAASUVORK5CYII=\n",
260 | "text/plain": [
261 | ""
262 | ]
263 | },
264 | "metadata": {
265 | "needs_background": "light"
266 | },
267 | "output_type": "display_data"
268 | }
269 | ],
270 | "source": [
271 | "random.seed(42)  # use int(time.time()) instead for a non-reproducible run\n",
272 | "bandit = Bandit(20)\n",
273 | "agent = UCB1(bandit, alpha=0.5)\n",
274 | "agent.add_listener('regret', update_regret)\n",
275 | "agent.add_listener('corr', update_rank_corr)\n",
276 | "agent.run(5000)\n",
277 | "plot_stats('alpha=0.5')"
278 | ]
279 | }
287 | ],
288 | "metadata": {
289 | "kernelspec": {
290 | "display_name": "Python 3",
291 | "language": "python",
292 | "name": "python3"
293 | },
294 | "language_info": {
295 | "codemirror_mode": {
296 | "name": "ipython",
297 | "version": 3
298 | },
299 | "file_extension": ".py",
300 | "mimetype": "text/x-python",
301 | "name": "python",
302 | "nbconvert_exporter": "python",
303 | "pygments_lexer": "ipython3",
304 | "version": "3.7.4"
305 | }
306 | },
307 | "nbformat": 4,
308 | "nbformat_minor": 2
309 | }
310 |
--------------------------------------------------------------------------------
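The last two cells of the notebook above rerun the same 20-armed bandit for 5000 rounds with `UCB1(bandit, alpha=2.0)` and `UCB1(bandit, alpha=0.5)`, recording regret and rank correlation through listeners. The notebook's own `Bandit` and `UCB1` classes are defined earlier and are not shown here, so what follows is only a minimal, self-contained sketch under stated assumptions: `BernoulliBandit` and `ucb1` are hypothetical names, each arm is assumed to pay a Bernoulli reward, and `alpha` is assumed to scale the exploration bonus as sqrt(alpha * ln t / n_a), so that `alpha=2.0` gives the textbook UCB1 bound and smaller values explore less.

```python
import math
import random


class BernoulliBandit:
    """Hypothetical stand-in for the notebook's Bandit: arm i pays 1 with probability probs[i]."""

    def __init__(self, n_arms, rng):
        self.rng = rng
        self.probs = [rng.random() for _ in range(n_arms)]

    def pull(self, arm):
        return 1.0 if self.rng.random() < self.probs[arm] else 0.0


def ucb1(bandit, n_steps, alpha=2.0):
    """Run UCB1 with exploration weight alpha (assumed parametrization).

    Returns the per-arm pull counts and the cumulative (pseudo-)regret after each step.
    """
    n_arms = len(bandit.probs)
    counts = [0] * n_arms
    values = [0.0] * n_arms  # running mean reward of each arm
    best = max(bandit.probs)
    regret = 0.0
    history = []
    for t in range(1, n_steps + 1):
        if t <= n_arms:
            arm = t - 1  # play every arm once so each count is non-zero
        else:
            # Optimism under uncertainty: mean estimate plus a confidence bonus.
            arm = max(
                range(n_arms),
                key=lambda a: values[a] + math.sqrt(alpha * math.log(t) / counts[a]),
            )
        reward = bandit.pull(arm)
        counts[arm] += 1
        values[arm] += (reward - values[arm]) / counts[arm]  # incremental mean
        regret += best - bandit.probs[arm]  # expected loss versus the best arm
        history.append(regret)
    return counts, history


# Compare the two settings used in the notebook on identically seeded bandits.
for alpha in (2.0, 0.5):
    bandit = BernoulliBandit(20, random.Random(42))
    counts, regret = ucb1(bandit, 5000, alpha=alpha)
    best_arm = bandit.probs.index(max(bandit.probs))
    print(f"alpha={alpha}: cumulative regret {regret[-1]:.1f}, "
          f"best arm pulled {counts[best_arm]} times")
```

A smaller `alpha` shrinks the confidence bonus, so the agent commits to its current best estimate sooner; that tends to lower regret when the early estimates are right and to raise it when they are not.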
/chapter10/Chatting_to_users.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "Chatting_to_users.ipynb",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3",
13 | "language": "python",
14 | "name": "python3"
15 | },
16 | "language_info": {
17 | "codemirror_mode": {
18 | "name": "ipython",
19 | "version": 3
20 | },
21 | "file_extension": ".py",
22 | "mimetype": "text/x-python",
23 | "name": "python",
24 | "nbconvert_exporter": "python",
25 | "pygments_lexer": "ipython3",
26 | "version": "3.6.11"
27 | }
28 | },
29 | "cells": [
30 | {
31 | "cell_type": "code",
32 | "metadata": {
33 | "id": "7MP6Si4U_74a",
34 | "outputId": "fccee37f-e58e-4802-9245-c27fec732cfa",
35 | "colab": {
36 | "base_uri": "https://localhost:8080/",
37 | "height": 1000
38 | }
39 | },
40 | "source": [
41 | "!pip install git+https://www.github.com/farizrahman4u/eywa.git"
42 | ],
43 | "execution_count": 1,
44 | "outputs": [
45 | {
46 | "output_type": "stream",
47 | "text": [
48 | "Collecting git+https://www.github.com/farizrahman4u/eywa.git\n",
49 | " Cloning https://www.github.com/farizrahman4u/eywa.git to /tmp/pip-req-build-5p2u62go\n",
50 | " Running command git clone -q https://www.github.com/farizrahman4u/eywa.git /tmp/pip-req-build-5p2u62go\n",
51 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from eywa==0.0.4) (1.18.5)\n",
52 | "Collecting dateparser\n",
53 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/c1/d5/5a2e51bc0058f66b54669735f739d27afc3eb453ab00520623c7ab168e22/dateparser-0.7.6-py2.py3-none-any.whl (362kB)\n",
54 | "\u001b[K |████████████████████████████████| 368kB 5.7MB/s \n",
55 | "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from eywa==0.0.4) (2.23.0)\n",
56 | "Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from eywa==0.0.4) (1.4.1)\n",
57 | "Collecting annoy\n",
58 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a1/5b/1c22129f608b3f438713b91cd880dc681d747a860afe3e8e0af86e921942/annoy-1.17.0.tar.gz (646kB)\n",
59 | "\u001b[K |████████████████████████████████| 655kB 21.7MB/s \n",
60 | "\u001b[?25hCollecting responder\n",
61 | " Downloading https://files.pythonhosted.org/packages/36/b9/99831331a9d22f79682f31d75b8be454c3e2781c2cca8d14e5882acf968a/responder-2.0.5-py3-none-any.whl\n",
62 | "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.6/dist-packages (from dateparser->eywa==0.0.4) (2.8.1)\n",
63 | "Requirement already satisfied: pytz in /usr/local/lib/python3.6/dist-packages (from dateparser->eywa==0.0.4) (2018.9)\n",
64 | "Requirement already satisfied: tzlocal in /usr/local/lib/python3.6/dist-packages (from dateparser->eywa==0.0.4) (1.5.1)\n",
65 | "Requirement already satisfied: regex!=2019.02.19 in /usr/local/lib/python3.6/dist-packages (from dateparser->eywa==0.0.4) (2019.12.20)\n",
66 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->eywa==0.0.4) (2.10)\n",
67 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->eywa==0.0.4) (2020.6.20)\n",
68 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->eywa==0.0.4) (1.24.3)\n",
69 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->eywa==0.0.4) (3.0.4)\n",
70 | "Requirement already satisfied: docopt in /usr/local/lib/python3.6/dist-packages (from responder->eywa==0.0.4) (0.6.2)\n",
71 | "Collecting starlette==0.12.*\n",
72 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a4/bd/b36fae5877bd5eca9dc72434273c11c9ba8a47b8fdfe9159ebc355d26500/starlette-0.12.13.tar.gz (47kB)\n",
73 | "\u001b[K |████████████████████████████████| 51kB 6.1MB/s \n",
74 | "\u001b[?25hCollecting whitenoise\n",
75 | " Downloading https://files.pythonhosted.org/packages/50/83/5d91949e370e52578a99ef6391c3b3e19f9fd1f5b4f58d5cbd6e2862d4a8/whitenoise-5.2.0-py2.py3-none-any.whl\n",
76 | "Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from responder->eywa==0.0.4) (3.13)\n",
77 | "Collecting graphql-server-core>=1.1\n",
78 | " Downloading https://files.pythonhosted.org/packages/2d/c4/911e0c61640a84b6f4929c854c6a16701a61bfc87e9af02ef17de4d699d7/graphql-server-core-2.0.0.tar.gz\n",
79 | "Collecting marshmallow\n",
80 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/6a/f8/516495245005449e9493972571d1d83c37e01bd853840887617e671a30ea/marshmallow-3.8.0-py2.py3-none-any.whl (46kB)\n",
81 | "\u001b[K |████████████████████████████████| 51kB 7.7MB/s \n",
82 | "\u001b[?25hCollecting requests-toolbelt\n",
83 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/60/ef/7681134338fc097acef8d9b2f8abe0458e4d87559c689a8c306d0957ece5/requests_toolbelt-0.9.1-py2.py3-none-any.whl (54kB)\n",
84 | "\u001b[K |████████████████████████████████| 61kB 8.3MB/s \n",
85 | "\u001b[?25hCollecting rfc3986\n",
86 | " Downloading https://files.pythonhosted.org/packages/78/be/7b8b99fd74ff5684225f50dd0e865393d2265656ef3b4ba9eaaaffe622b8/rfc3986-1.4.0-py2.py3-none-any.whl\n",
87 | "Collecting apispec>=1.0.0b1\n",
88 | " Downloading https://files.pythonhosted.org/packages/17/57/45bfcbe3c406597164983b8d383f44716aafcd15dd79eba3a0355fbb1b24/apispec-4.0.0-py2.py3-none-any.whl\n",
89 | "Collecting aiofiles\n",
90 | " Downloading https://files.pythonhosted.org/packages/f4/2b/078a9771ae4b67e36b0c2a973df845260833a4eb088b81c84b738509b4c4/aiofiles-0.5.0-py3-none-any.whl\n",
91 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.6/dist-packages (from responder->eywa==0.0.4) (2.11.2)\n",
92 | "Collecting apistar\n",
93 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d6/0c/3066b856f661bc58b16c1a2ff4eeed25dc4f0d4618871b3454066dad89e0/apistar-0.7.2.tar.gz (3.3MB)\n",
94 | "\u001b[K |████████████████████████████████| 3.3MB 32.8MB/s \n",
95 | "\u001b[?25hCollecting uvicorn==0.10.*\n",
96 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/45/d6/8d38b80e1f1d30ef0c2a44b1a6a9ea257cd9afba6abb32100e8cf3783638/uvicorn-0.10.9-py3-none-any.whl (42kB)\n",
97 | "\u001b[K |████████████████████████████████| 51kB 8.2MB/s \n",
98 | "\u001b[?25hRequirement already satisfied: itsdangerous in /usr/local/lib/python3.6/dist-packages (from responder->eywa==0.0.4) (1.1.0)\n",
99 | "Collecting python-multipart\n",
100 | " Downloading https://files.pythonhosted.org/packages/46/40/a933ac570bf7aad12a298fc53458115cc74053474a72fbb8201d7dc06d3d/python-multipart-0.0.5.tar.gz\n",
101 | "Collecting graphene<3.0\n",
102 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/05/97/45e743b372f65a619f8d1eb2897efb74fb1b0ffddc731ad37e0aa187ec5c/graphene-2.1.8-py2.py3-none-any.whl (107kB)\n",
103 | "\u001b[K |████████████████████████████████| 112kB 29.3MB/s \n",
104 | "\u001b[?25hRequirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil->dateparser->eywa==0.0.4) (1.15.0)\n",
105 | "Collecting graphql-core<3,>=2.3\n",
106 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/11/71/d51beba3d8986fa6d8670ec7bcba989ad6e852d5ae99d95633e5dacc53e7/graphql_core-2.3.2-py2.py3-none-any.whl (252kB)\n",
107 | "\u001b[K |████████████████████████████████| 256kB 42.0MB/s \n",
108 | "\u001b[?25hRequirement already satisfied: promise<3,>=2.3 in /usr/local/lib/python3.6/dist-packages (from graphql-server-core>=1.1->responder->eywa==0.0.4) (2.3)\n",
109 | "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.6/dist-packages (from jinja2->responder->eywa==0.0.4) (1.1.1)\n",
110 | "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from apistar->responder->eywa==0.0.4) (7.1.2)\n",
111 | "Collecting typesystem\n",
112 | " Downloading https://files.pythonhosted.org/packages/bd/be/158f4a02dde348e16cec7ab9603df521d9175d37ff7d048f086d9b76eb7b/typesystem-0.2.4.tar.gz\n",
113 | "Collecting h11==0.9.*\n",
114 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/5a/fd/3dad730b0f95e78aeeb742f96fa7bbecbdd56a58e405d3da440d5bfb90c6/h11-0.9.0-py2.py3-none-any.whl (53kB)\n",
115 | "\u001b[K |████████████████████████████████| 61kB 9.4MB/s \n",
116 | "\u001b[?25hCollecting uvloop>=0.14.0; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"pypy\"\n",
117 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/41/48/586225bbb02d3bdca475b17e4be5ce5b3f09da2d6979f359916c1592a687/uvloop-0.14.0-cp36-cp36m-manylinux2010_x86_64.whl (3.9MB)\n",
118 | "\u001b[K |████████████████████████████████| 3.9MB 55.2MB/s \n",
119 | "\u001b[?25hCollecting httptools==0.0.13; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"pypy\"\n",
120 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/1b/03/215969db11abe8741e9c266a4cbe803a372bd86dd35fa0084c4df6d4bd00/httptools-0.0.13.tar.gz (104kB)\n",
121 | "\u001b[K |████████████████████████████████| 112kB 61.1MB/s \n",
122 | "\u001b[?25hCollecting websockets==8.*\n",
123 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/bb/d9/856af84843912e2853b1b6e898ac8b802989fcf9ecf8e8445a1da263bf3b/websockets-8.1-cp36-cp36m-manylinux2010_x86_64.whl (78kB)\n",
124 | "\u001b[K |████████████████████████████████| 81kB 12.4MB/s \n",
125 | "\u001b[?25hCollecting aniso8601<=7,>=3\n",
126 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/45/a4/b4fcadbdab46c2ec2d2f6f8b4ab3f64fd0040789ac7f065eba82119cd602/aniso8601-7.0.0-py2.py3-none-any.whl (42kB)\n",
127 | "\u001b[K |████████████████████████████████| 51kB 9.9MB/s \n",
128 | "\u001b[?25hCollecting graphql-relay<3,>=2\n",
129 | " Downloading https://files.pythonhosted.org/packages/94/48/6022ea2e89cb936c3b933a0409c6e29bf8a68c050fe87d97f98aff6e5e9e/graphql_relay-2.0.1-py3-none-any.whl\n",
130 | "Collecting rx<2,>=1.6\n",
131 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/33/0f/5ef4ac78e2a538cc1b054eb86285fe0bf7a5dbaeaac2c584757c300515e2/Rx-1.6.1-py2.py3-none-any.whl (179kB)\n",
132 | "\u001b[K |████████████████████████████████| 184kB 59.2MB/s \n",
133 | "\u001b[?25hBuilding wheels for collected packages: eywa, annoy, starlette, graphql-server-core, apistar, python-multipart, typesystem, httptools\n",
134 | " Building wheel for eywa (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
135 | " Created wheel for eywa: filename=eywa-0.0.4-cp36-none-any.whl size=152474 sha256=5e7e5536e3d5706832e55681051ad93aad4d636d0ac6ab84957c0eb208202bf2\n",
136 | " Stored in directory: /tmp/pip-ephem-wheel-cache-25lny_ly/wheels/e3/d3/11/d2c9b9b41cb0f2b8397d58387cc4c9ce92956c5a16fcb08c02\n",
137 | " Building wheel for annoy (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
138 | " Created wheel for annoy: filename=annoy-1.17.0-cp36-cp36m-linux_x86_64.whl size=390355 sha256=9d307bfa42def189c87ba081291de398465cd20b4a9b8bd9ac3fb8dbf2648207\n",
139 | " Stored in directory: /root/.cache/pip/wheels/3a/c5/59/cce7e67b52c8e987389e53f917b6bb2a9d904a03246fadcb1e\n",
140 | " Building wheel for starlette (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
141 | " Created wheel for starlette: filename=starlette-0.12.13-cp36-none-any.whl size=58383 sha256=67d11e9e0501f5c7991392c63c200a1734746e7e60b5bdaad9ad3fe47e231750\n",
142 | " Stored in directory: /root/.cache/pip/wheels/08/be/e8/2a06d0515b4730f45415eb97bf1b9ebbc4a22c67b4f5113885\n",
143 | " Building wheel for graphql-server-core (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
144 | " Created wheel for graphql-server-core: filename=graphql_server_core-2.0.0-py2.py3-none-any.whl size=7728 sha256=c2c3863265cf0838f1a8c31242fabb2e77b9a9a77511497b6e8f41408cb19e59\n",
145 | " Stored in directory: /root/.cache/pip/wheels/0c/81/82/c8d678001af54abb231abc5c521cb86ad7c7176bab5aff8673\n",
146 | " Building wheel for apistar (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
147 | " Created wheel for apistar: filename=apistar-0.7.2-cp36-none-any.whl size=3312662 sha256=cd0d091d14110dbf9c02ee294332de9f9b1075f4ef37817819a97d440541c911\n",
148 | " Stored in directory: /root/.cache/pip/wheels/32/b5/a5/49563b6328da6f18de11d34c9167b5ff036808b9bf2e2941b3\n",
149 | " Building wheel for python-multipart (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
150 | " Created wheel for python-multipart: filename=python_multipart-0.0.5-cp36-none-any.whl size=31671 sha256=1b2f621ed4ce76812c11e62712c1fde6a82645f5c38c684fa8671e0ae6abb7d1\n",
151 | " Stored in directory: /root/.cache/pip/wheels/f0/e6/66/14a866a3cbd6a0cabfbef91f7edf40aa03595ef6c88d6d1be4\n",
152 | " Building wheel for typesystem (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
153 | " Created wheel for typesystem: filename=typesystem-0.2.4-cp36-none-any.whl size=26039 sha256=0d345b5a7f935078b687530190d6ca238e56a3deb64050c4e2c6305f0e264ff8\n",
154 | " Stored in directory: /root/.cache/pip/wheels/25/05/bf/d42f7a013cc83b042eb546b5749108d8a664bac99f44dacc9e\n",
155 | " Building wheel for httptools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
156 | " Created wheel for httptools: filename=httptools-0.0.13-cp36-cp36m-linux_x86_64.whl size=212529 sha256=0234bf32b54033eb1776e7c977d6a1d5491a1ec941c719532c926c7bfe372856\n",
157 | " Stored in directory: /root/.cache/pip/wheels/e8/3e/2e/013f99b42efc25cf3589730cf380738e46b1e5edaf2f78d525\n",
158 | "Successfully built eywa annoy starlette graphql-server-core apistar python-multipart typesystem httptools\n",
159 | "Installing collected packages: dateparser, annoy, starlette, whitenoise, rx, graphql-core, graphql-server-core, marshmallow, requests-toolbelt, rfc3986, apispec, aiofiles, typesystem, apistar, h11, uvloop, httptools, websockets, uvicorn, python-multipart, aniso8601, graphql-relay, graphene, responder, eywa\n",
160 | "Successfully installed aiofiles-0.5.0 aniso8601-7.0.0 annoy-1.17.0 apispec-4.0.0 apistar-0.7.2 dateparser-0.7.6 eywa-0.0.4 graphene-2.1.8 graphql-core-2.3.2 graphql-relay-2.0.1 graphql-server-core-2.0.0 h11-0.9.0 httptools-0.0.13 marshmallow-3.8.0 python-multipart-0.0.5 requests-toolbelt-0.9.1 responder-2.0.5 rfc3986-1.4.0 rx-1.6.1 starlette-0.12.13 typesystem-0.2.4 uvicorn-0.10.9 uvloop-0.14.0 websockets-8.1 whitenoise-5.2.0\n"
161 | ],
162 | "name": "stdout"
163 | }
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "metadata": {
169 | "id": "g0rnWtvyJPK6",
170 | "outputId": "318b9ad3-37f5-4692-8438-f5f8c5466ab8",
171 | "colab": {
172 | "base_uri": "https://localhost:8080/",
173 | "height": 241
174 | }
175 | },
176 | "source": [
177 | "!pip install pyowm"
178 | ],
179 | "execution_count": 2,
180 | "outputs": [
181 | {
182 | "output_type": "stream",
183 | "text": [
184 | "Collecting pyowm\n",
185 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/41/2a/83e26bc87763d0d34767ddc5c875608d4a0a0da66e59730a15c55aec6eff/pyowm-2.10.0-py3-none-any.whl (3.7MB)\n",
186 | "\u001b[K |████████████████████████████████| 3.8MB 4.4MB/s \n",
187 | "\u001b[?25hRequirement already satisfied: requests<3,>=2.20.0 in /usr/local/lib/python3.6/dist-packages (from pyowm) (2.23.0)\n",
188 | "Collecting geojson<3,>=2.3.0\n",
189 | " Downloading https://files.pythonhosted.org/packages/e4/8d/9e28e9af95739e6d2d2f8d4bef0b3432da40b7c3588fbad4298c1be09e48/geojson-2.5.0-py2.py3-none-any.whl\n",
190 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->pyowm) (2.10)\n",
191 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->pyowm) (2020.6.20)\n",
192 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->pyowm) (1.24.3)\n",
193 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.20.0->pyowm) (3.0.4)\n",
194 | "Installing collected packages: geojson, pyowm\n",
195 | "Successfully installed geojson-2.5.0 pyowm-2.10.0\n"
196 | ],
197 | "name": "stdout"
198 | }
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "metadata": {
204 | "id": "U37DdYnTa5DO",
205 | "outputId": "2ffe4910-d3f4-4702-fe82-1acd1cd7a168",
206 | "colab": {
207 | "base_uri": "https://localhost:8080/",
208 | "height": 340
209 | }
210 | },
211 | "source": [
212 | "from eywa.nlu import Classifier\n",
213 | "\n",
214 | "CONV_SAMPLES = {\n",
215 | " 'greetings' : ['Hi', 'hello', 'How are you', 'hey there', 'hey'],\n",
216 | " 'taxi' : ['book a cab', 'need a ride', 'find me a cab'],\n",
217 | " 'weather' : ['what is the weather in tokyo', 'weather germany',\n",
218 | " 'what is the weather like in kochi',\n",
219 | " 'what is the weather like', 'is it hot outside'],\n",
220 | " 'datetime' : ['what day is today', 'todays date', 'what time is it now',\n",
221 | " 'time now', 'what is the time'],\n",
222 | " 'music' : ['play the Beatles', 'shuffle songs', 'make a sound']\n",
223 | "}\n",
224 | "\n",
225 | "CLF = Classifier()\n",
226 | "for key in CONV_SAMPLES:\n",
227 | " CLF.fit(CONV_SAMPLES[key], key)\n",
228 | "\n",
229 | "print(CLF.predict('will it rain today')) # intended: 'weather' (this run printed 'datetime')\n",
230 | "print(CLF.predict('play playlist rock n\\'roll')) # intended: 'music'\n",
231 | "print(CLF.predict('what\\'s the hour?')) # intended: 'datetime' (this run printed 'weather')"
232 | ],
233 | "execution_count": 3,
234 | "outputs": [
235 | {
236 | "output_type": "stream",
237 | "text": [
238 | "Downloading embeddings...\n",
239 | "Source: https://github.com/explosion/sense2vec/releases/download/v1.0.0a0/reddit_vectors-1.1.0.tar.gz\n",
240 | "Destination /root/.eywa/lang/en/embeddings/reddit_vectors-1.1.0.tar.gz\n",
241 | "Size: 560.460747718811MB\n",
242 | "586899456/587685689 [============================>.] - 99.866215% - ETA: 0sDone.\n",
243 | "Extracting...\n",
244 | "Done.\n",
245 | "Converting embeddings...\n",
246 | "Done.\n",
247 | "Seems you are running the program for the first time. Building index...\n",
248 | "Building tree...\n",
249 | "1193838/1195260 [============================>.] - 99.881030% - ETA: 0sDone.\n",
250 | "Creating databases...\n",
251 | "1194090/1195260 [============================>.] - 99.902113% - ETA: 0s\n",
252 | "Converting words to indices...\n",
253 | "1191011/1195260 [============================>.] - 99.644512% - ETA: 0sDone.\n",
254 | "datetime\n",
255 | "music\n",
256 | "weather\n"
257 | ],
258 | "name": "stdout"
259 | }
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "metadata": {
265 | "id": "KBS524RIGk_h"
266 | },
267 | "source": [
268 | "from eywa.nlu import EntityExtractor\n",
269 | "\n",
270 | "X_WEATHER = ['what is the weather in tokyo', 'weather germany', 'what is the weather like in kochi']\n",
271 | "Y_WEATHER = [{'intent': 'weather', 'place': 'tokyo'}, {'intent': 'weather', 'place': 'germany'},\n",
272 | " {'intent': 'weather', 'place': 'kochi'}]\n",
273 | "\n",
274 | "EX_WEATHER = EntityExtractor()\n",
275 | "EX_WEATHER.fit(X_WEATHER, Y_WEATHER)\n"
276 | ],
277 | "execution_count": 4,
278 | "outputs": []
279 | },
280 | {
281 | "cell_type": "code",
282 | "metadata": {
283 | "id": "Cj70xxYjGsr9",
284 | "outputId": "e302580e-880c-4f52-d609-1960d154c8fb",
285 | "colab": {
286 | "base_uri": "https://localhost:8080/",
287 | "height": 156
288 | }
289 | },
290 | "source": [
291 | "EX_WEATHER.predict('what is the weather in London')"
292 | ],
293 | "execution_count": 5,
294 | "outputs": [
295 | {
296 | "output_type": "stream",
297 | "text": [
298 | "WARNING:tensorflow:Layer gru will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
299 | "WARNING:tensorflow:Layer gru will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
300 | "WARNING:tensorflow:Layer gru will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
301 | "WARNING:tensorflow:Layer gru_1 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
302 | "WARNING:tensorflow:Layer gru_1 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
303 | "WARNING:tensorflow:Layer gru_1 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n"
304 | ],
305 | "name": "stdout"
306 | },
307 | {
308 | "output_type": "execute_result",
309 | "data": {
310 | "text/plain": [
311 | "{'intent': 'weather', 'place': 'London'}"
312 | ]
313 | },
314 | "metadata": {
315 | "tags": []
316 | },
317 | "execution_count": 5
318 | }
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "metadata": {
324 | "id": "OKm0N9yAaosB",
325 | "outputId": "882f4bd4-cd0a-4383-c12a-2b1f8940baf2",
326 | "colab": {
327 | "base_uri": "https://localhost:8080/",
328 | "height": 156
329 | }
330 | },
331 | "source": [
332 | "from eywa.nlu import EntityExtractor\n",
333 | "\n",
334 | "x = ['what is the weather in tokyo', 'what is the weather', 'what is the weather like in kochi']\n",
335 | "y = [{'intent': 'weather', 'place': 'tokyo'}, {'intent': 'weather', 'place': 'here'}, {'intent': 'weather', 'place': 'kochi'}]\n",
336 | "\n",
337 | "ex = EntityExtractor()\n",
338 | "ex.fit(x, y)\n",
339 | "\n",
340 | "x_test = 'what is the weather in london like'\n",
341 | "print(ex.predict(x_test))\n"
342 | ],
343 | "execution_count": 6,
344 | "outputs": [
345 | {
346 | "output_type": "stream",
347 | "text": [
348 | "WARNING:tensorflow:Layer gru_2 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
349 | "WARNING:tensorflow:Layer gru_2 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
350 | "WARNING:tensorflow:Layer gru_2 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
351 | "WARNING:tensorflow:Layer gru_3 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
352 | "WARNING:tensorflow:Layer gru_3 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
353 | "WARNING:tensorflow:Layer gru_3 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
354 | "{'intent': 'weather', 'place': 'like'}\n"
355 | ],
356 | "name": "stdout"
357 | }
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "metadata": {
363 | "id": "hMd46fqHKrKh",
364 | "outputId": "6c240115-190e-4271-e063-fa561361a548",
365 | "colab": {
366 | "base_uri": "https://localhost:8080/",
367 | "height": 51
368 | }
369 | },
370 | "source": [
371 | "from pyowm import OWM\n",
372 | "import logging\n",
373 | "\n",
374 | "LOGGER = logging.getLogger('main')\n",
375 | "# put your OpenWeatherMap API key here:\n",
376 | "mgr = OWM('API-key')\n",
377 | "# this uses the pyowm 2.x API; from pyowm 3 on, use OWM('API-key').weather_manager() and observation.weather.detailed_status instead\n",
378 | "\n",
379 | "def get_weather_forecast(place):\n",
380 | " LOGGER.warning(place)\n",
381 | " observation = mgr.weather_at_place(place)\n",
382 | " return observation.get_weather().get_detailed_status()\n",
383 | "\n",
384 | "print(get_weather_forecast('London'))"
385 | ],
386 | "execution_count": 18,
387 | "outputs": [
388 | {
389 | "output_type": "stream",
390 | "text": [
391 | "London\n"
392 | ],
393 | "name": "stderr"
394 | },
395 | {
396 | "output_type": "stream",
397 | "text": [
398 | "broken clouds\n"
399 | ],
400 | "name": "stdout"
401 | }
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "metadata": {
407 | "id": "m4e8gz-aMwY2"
408 | },
409 | "source": [
410 | "X_GREETING = ['Hii', 'helllo', 'Howdy', 'hey there', 'hey', 'Hi']\n",
411 | "Y_GREETING = [{'greet': 'Hii'}, {'greet': 'helllo'}, {'greet': 'Howdy'},\n",
412 | " {'greet': 'hey'}, {'greet': 'hey'}, {'greet': 'Hi'}]\n",
413 | "\n",
414 | "EX_GREETING = EntityExtractor()\n",
415 | "EX_GREETING.fit(X_GREETING, Y_GREETING)"
416 | ],
417 | "execution_count": 8,
418 | "outputs": []
419 | },
420 | {
421 | "cell_type": "code",
422 | "metadata": {
423 | "id": "RjAweeTDMyv7"
424 | },
425 | "source": [
426 | "X_DATETIME = ['what day is today', 'date today', 'what time is it now', 'time now']\n",
427 | "Y_DATETIME = [{'intent' : 'day', 'target': 'today'}, {'intent' : 'date', 'target': 'today'},\n",
428 | " {'intent' : 'time', 'target': 'now'}, {'intent' : 'time', 'target': 'now'}]\n",
429 | "\n",
430 | "EX_DATETIME = EntityExtractor()\n",
431 | "EX_DATETIME.fit(X_DATETIME, Y_DATETIME)"
432 | ],
433 | "execution_count": 9,
434 | "outputs": []
435 | },
436 | {
437 | "cell_type": "code",
438 | "metadata": {
439 | "id": "r9wyTjwyNDGE"
440 | },
441 | "source": [
442 | "import datetime\n",
443 | "\n",
444 | "_EXTRACTORS = {\n",
445 | " 'taxi': None,\n",
446 | " 'weather': EX_WEATHER,\n",
447 | " 'greetings': EX_GREETING,\n",
448 | " 'datetime': EX_DATETIME,\n",
449 | " 'music': None\n",
450 | "}\n",
451 | "\n",
452 | "def question_and_answer(u_query):\n",
453 | " '''Answer a user query\n",
454 | " '''\n",
455 | " q_class = CLF.predict(u_query)\n",
456 | " if _EXTRACTORS[q_class] is None:\n",
457 | " return 'Sorry, you have to upgrade your software!'\n",
458 | "\n",
459 | " q_entities = _EXTRACTORS[q_class].predict(u_query)\n",
460 | " if q_class == 'greetings':\n",
461 | " return q_entities.get('greet', 'hello')\n",
462 | " \n",
463 | " if q_class == 'weather':\n",
464 | " place = q_entities.get('place', 'London').replace('_', ' ')\n",
465 | " return 'The forecast for {} is {}'.format(\n",
466 | " place,\n",
467 | " get_weather_forecast(place)\n",
468 | " )\n",
469 | "\n",
470 | " if q_class == 'datetime':\n",
471 | " return 'Today\\'s date is {}'.format(\n",
472 | " datetime.datetime.today().strftime('%B %d, %Y')\n",
473 | " )\n",
474 | " \n",
475 | " return 'I couldn\\'t understand what you said. I am sorry.'"
476 | ],
477 | "execution_count": 10,
478 | "outputs": []
479 | },
480 | {
481 | "cell_type": "code",
482 | "metadata": {
483 | "id": "FFK72xR7bZ6M",
484 | "outputId": "c1d51358-e25a-42ba-95c2-92f966a87b92",
485 | "colab": {
486 | "base_uri": "https://localhost:8080/",
487 | "height": 51
488 | }
489 | },
490 | "source": [
491 | "while True:\n",
492 | " query = input('\\nHow can I help you? ')\n",
493 | " print(question_and_answer(query))"
494 | ],
495 | "execution_count": null,
496 | "outputs": [
497 | {
498 | "output_type": "stream",
499 | "text": [
500 | "germany\n"
501 | ],
502 | "name": "stderr"
503 | },
504 | {
505 | "output_type": "stream",
506 | "text": [
507 | "The forecast for germany is light rain\n"
508 | ],
509 | "name": "stdout"
510 | }
511 | ]
512 | },
513 | {
514 | "cell_type": "code",
515 | "metadata": {
516 | "id": "mzzfkiINolVv"
517 | },
518 | "source": [
519 | ""
520 | ],
521 | "execution_count": null,
522 | "outputs": []
523 | }
524 | ]
525 | }
--------------------------------------------------------------------------------
/chapter10/Classifying_newsgroups.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "A9Dy9CPcPab4"
7 | },
8 | "source": [
9 | "This recipe initially follows the scikit-learn tutorial on working with text data: https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {
16 | "id": "rVStCLxetO_B"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "from sklearn.datasets import fetch_20newsgroups"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {
27 | "colab": {
28 | "base_uri": "https://localhost:8080/",
29 | "height": 51
30 | },
31 | "id": "VJCtjWYKPC5U",
32 | "outputId": "b371edfe-fde5-4a04-8b8a-fd444d921b08"
33 | },
34 | "outputs": [
35 | {
36 | "name": "stderr",
37 | "output_type": "stream",
38 | "text": [
39 | "Downloading 20news dataset. This may take a few minutes.\n",
40 | "Downloading dataset from https://ndownloader.figshare.com/files/5975967 (14 MB)\n"
41 | ]
42 | }
43 | ],
44 | "source": [
45 | "categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']\n",
46 | "twenty_train = fetch_20newsgroups(\n",
47 | " subset='train',\n",
48 | " categories=categories,\n",
49 | " shuffle=True,\n",
50 | " random_state=42\n",
51 | " )"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {
57 | "id": "Lex3NsBSPSdU"
58 | },
59 | "source": [
60 | "It's a small dataset. Let's check the number of training documents:"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 3,
66 | "metadata": {
67 | "colab": {
68 | "base_uri": "https://localhost:8080/",
69 | "height": 34
70 | },
71 | "id": "mB5zYqo6PE9k",
72 | "outputId": "9cdf1f5b-7520-4ec7-dde7-dc756d4dd482"
73 | },
74 | "outputs": [
75 | {
76 | "data": {
77 | "text/plain": [
78 | "2257"
79 | ]
80 | },
81 | "execution_count": 3,
82 | "metadata": {
83 | "tags": []
84 | },
85 | "output_type": "execute_result"
86 | }
87 | ],
88 | "source": [
89 | "len(twenty_train.filenames)"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 4,
95 | "metadata": {
96 | "colab": {
97 | "base_uri": "https://localhost:8080/",
98 | "height": 68
99 | },
100 | "id": "WedME_XCPRKB",
101 | "outputId": "7b419e25-a3f9-472e-ab93-ef5a61a5c7ae"
102 | },
103 | "outputs": [
104 | {
105 | "name": "stdout",
106 | "output_type": "stream",
107 | "text": [
108 | "From: sd345@city.ac.uk (Michael Collier)\n",
109 | "Subject: Converting images to HP LaserJet III?\n",
110 | "Nntp-Posting-Host: hampton\n"
111 | ]
112 | }
113 | ],
114 | "source": [
115 | "print(\"\\n\".join(twenty_train.data[0].split(\"\\n\")[:3]))"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {
121 | "id": "mALQhYvMSbrZ"
122 | },
123 | "source": [
124 | "# Using a bag-of-words approach with a classifier"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 5,
130 | "metadata": {
131 | "id": "Jlp8DsbfRAYY"
132 | },
133 | "outputs": [],
134 | "source": [
135 | "from sklearn.pipeline import Pipeline\n",
136 | "from sklearn.feature_extraction.text import CountVectorizer\n",
137 | "from sklearn.feature_extraction.text import TfidfTransformer\n",
138 | "from sklearn.ensemble import RandomForestClassifier\n",
139 | "\n",
140 | "text_clf = Pipeline([\n",
141 | " ('vect', CountVectorizer()),\n",
142 | " ('tfidf', TfidfTransformer()),\n",
143 | " ('clf', RandomForestClassifier()),\n",
144 | "])"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 6,
150 | "metadata": {
151 | "colab": {
152 | "base_uri": "https://localhost:8080/",
153 | "height": 408
154 | },
155 | "id": "DcDpL_6EQYAM",
156 | "outputId": "54df2e19-41bd-405b-9fb7-3aa1230fbdfc"
157 | },
158 | "outputs": [
159 | {
160 | "data": {
161 | "text/plain": [
162 | "Pipeline(memory=None,\n",
163 | " steps=[('vect',\n",
164 | " CountVectorizer(analyzer='word', binary=False,\n",
165 | " decode_error='strict',\n",
166 | " dtype=<class 'numpy.int64'>, encoding='utf-8',\n",
167 | " input='content', lowercase=True, max_df=1.0,\n",
168 | " max_features=None, min_df=1,\n",
169 | " ngram_range=(1, 1), preprocessor=None,\n",
170 | " stop_words=None, strip_accents=None,\n",
171 | " token_pattern='(?u)\\\\b\\\\w\\\\w+\\\\b',\n",
172 | " tokenizer=None, vocabulary=Non...\n",
173 | " RandomForestClassifier(bootstrap=True, ccp_alpha=0.0,\n",
174 | " class_weight=None, criterion='gini',\n",
175 | " max_depth=None, max_features='auto',\n",
176 | " max_leaf_nodes=None, max_samples=None,\n",
177 | " min_impurity_decrease=0.0,\n",
178 | " min_impurity_split=None,\n",
179 | " min_samples_leaf=1, min_samples_split=2,\n",
180 | " min_weight_fraction_leaf=0.0,\n",
181 | " n_estimators=100, n_jobs=None,\n",
182 | " oob_score=False, random_state=None,\n",
183 | " verbose=0, warm_start=False))],\n",
184 | " verbose=False)"
185 | ]
186 | },
187 | "execution_count": 6,
188 | "metadata": {
189 | "tags": []
190 | },
191 | "output_type": "execute_result"
192 | }
193 | ],
194 | "source": [
195 | "text_clf.fit(twenty_train.data, twenty_train.target)"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 7,
201 | "metadata": {
202 | "colab": {
203 | "base_uri": "https://localhost:8080/",
204 | "height": 34
205 | },
206 | "id": "YKrPF-VYR_1e",
207 | "outputId": "e0ae0588-d30d-4b34-b8ee-8c15bb49e6fe"
208 | },
209 | "outputs": [
210 | {
211 | "data": {
212 | "text/plain": [
213 | "0.8009320905459387"
214 | ]
215 | },
216 | "execution_count": 7,
217 | "metadata": {
218 | "tags": []
219 | },
220 | "output_type": "execute_result"
221 | }
222 | ],
223 | "source": [
224 | "import numpy as np\n",
225 | "twenty_test = fetch_20newsgroups(\n",
226 | " subset='test',\n",
227 | " categories=categories,\n",
228 | " shuffle=True,\n",
229 | " random_state=42\n",
230 | ")\n",
231 | "predicted = text_clf.predict(twenty_test.data)\n",
232 | "np.mean(predicted == twenty_test.target)"
233 | ]
234 | },
235 | {
236 | "cell_type": "markdown",
237 | "metadata": {
238 | "id": "AphYch6gSh5i"
239 | },
240 | "source": [
241 | "# Using a word embedding with a classifier"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 8,
247 | "metadata": {
248 | "colab": {
249 | "base_uri": "https://localhost:8080/",
250 | "height": 34
251 | },
252 | "id": "WMR8vAXBbhq0",
253 | "outputId": "69499c36-041a-4ec6-c5f7-36f0a122cd6b"
254 | },
255 | "outputs": [
256 | {
257 | "name": "stdout",
258 | "output_type": "stream",
259 | "text": [
260 | "Mounted at /gdrive\n"
261 | ]
262 | }
263 | ],
264 | "source": [
265 | "# If you run this in Colab, you can store the embeddings in your Google Drive.\n",
266 | "# This is optional.\n",
267 | "# If you don't want to, point filepath in the next cell to a local path outside /gdrive (make sure the target directory exists)\n",
268 | "from google.colab import drive\n",
269 | "drive.mount('/gdrive')"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 23,
275 | "metadata": {
276 | "colab": {
277 | "base_uri": "https://localhost:8080/",
278 | "height": 52
279 | },
280 | "id": "BBSlCzAWYmIz",
281 | "outputId": "dbfb86fa-dd8c-4812-bc69-20464b4540f9"
282 | },
283 | "outputs": [
284 | {
285 | "name": "stdout",
286 | "output_type": "stream",
287 | "text": [
288 | "Requirement already satisfied: wget in /usr/local/lib/python3.6/dist-packages (3.2)\n"
289 | ]
290 | },
291 | {
292 | "data": {
293 | "application/vnd.google.colaboratory.intrinsic+json": {
294 | "type": "string"
295 | },
296 | "text/plain": [
297 | "'/gdrive/My Drive/embeddings/wiki.en.vec'"
298 | ]
299 | },
300 | "execution_count": 23,
301 | "metadata": {
302 | "tags": []
303 | },
304 | "output_type": "execute_result"
305 | }
306 | ],
307 | "source": [
308 | "!pip install wget\n",
309 | "import os\n",
310 | "import wget\n",
311 | "filepath = '/gdrive/My Drive/embeddings/wiki.en.vec'\n",
312 | "if not os.path.isfile(filepath):\n",
313 | " wget.download(\n",
314 | " 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.en.vec',\n",
315 | " filepath\n",
316 | " )"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": 25,
322 | "metadata": {
323 | "colab": {
324 | "base_uri": "https://localhost:8080/",
325 | "height": 71
326 | },
327 | "id": "sIPfOVmoS5tB",
328 | "outputId": "b356945c-51fd-4fe6-bf73-81ef519d994d"
329 | },
330 | "outputs": [
331 | {
332 | "name": "stderr",
333 | "output_type": "stream",
334 | "text": [
335 | "/usr/local/lib/python3.6/dist-packages/smart_open/smart_open_lib.py:252: UserWarning: This function is deprecated, use smart_open.open instead. See the migration notes for details: https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst#migrating-to-the-new-open-function\n",
336 | " 'See the migration notes for details: %s' % _MIGRATION_NOTES_URL\n"
337 | ]
338 | }
339 | ],
340 | "source": [
341 | "from gensim.models import KeyedVectors\n",
342 | "\n",
343 | "model = KeyedVectors.load_word2vec_format(\n",
344 | " filepath,\n",
345 | " binary=False, encoding='utf8'\n",
346 | ")"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": 26,
352 | "metadata": {
353 | "colab": {
354 | "base_uri": "https://localhost:8080/",
355 | "height": 88
356 | },
357 | "id": "43F4upWkCFaB",
358 | "outputId": "73c3cb58-ddd1-4def-e7a0-fd3797bd0618"
359 | },
360 | "outputs": [
361 | {
362 | "name": "stderr",
363 | "output_type": "stream",
364 | "text": [
365 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: DeprecationWarning: Call to deprecated `wv` (Attribute will be removed in 4.0.0, use self instead).\n",
366 | " import sys\n"
367 | ]
368 | },
369 | {
370 | "data": {
371 | "text/plain": [
372 | "(1, 300)"
373 | ]
374 | },
375 | "execution_count": 26,
376 | "metadata": {
377 | "tags": []
378 | },
379 | "output_type": "execute_result"
380 | }
381 | ],
382 | "source": [
383 | "import numpy as np\n",
384 | "from tensorflow.keras.preprocessing.text import text_to_word_sequence\n",
385 | "\n",
386 | "def embed_text(text: str):\n",
387 | " vector_list = [\n",
388 | " model[w].reshape(-1, 1) for w in text_to_word_sequence(text)\n",
389 | " if w in model\n",
390 | " ]\n",
391 | " if len(vector_list) > 0:\n",
392 | " return np.mean(\n",
393 | " np.concatenate(vector_list, axis=1),\n",
394 | " axis=1\n",
395 | " ).reshape(1, 300)\n",
396 | " else:\n",
397 | " return np.zeros(shape=(1, 300))\n",
398 | "\n",
399 | "\n",
400 | "embed_text('training run').shape"
401 | ]
402 | },
403 | {
404 | "cell_type": "code",
405 | "execution_count": 27,
406 | "metadata": {
407 | "colab": {
408 | "base_uri": "https://localhost:8080/",
409 | "height": 71
410 | },
411 | "id": "V3xnBruIOSxb",
412 | "outputId": "b8229485-b636-4be9-ccf2-6c27210e010a"
413 | },
414 | "outputs": [
415 | {
416 | "name": "stderr",
417 | "output_type": "stream",
418 | "text": [
419 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: DeprecationWarning: Call to deprecated `wv` (Attribute will be removed in 4.0.0, use self instead).\n",
420 | " import sys\n"
421 | ]
422 | }
423 | ],
424 | "source": [
425 | "train_transformed = np.concatenate(\n",
426 | " [embed_text(t) for t in twenty_train.data]\n",
427 | ")"
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": 28,
433 | "metadata": {
434 | "colab": {
435 | "base_uri": "https://localhost:8080/",
436 | "height": 34
437 | },
438 | "id": "26NodEUqO_cB",
439 | "outputId": "66be743e-e982-49a8-bf8c-181c0beaa7db"
440 | },
441 | "outputs": [
442 | {
443 | "data": {
444 | "text/plain": [
445 | "(2257, 300)"
446 | ]
447 | },
448 | "execution_count": 28,
449 | "metadata": {
450 | "tags": []
451 | },
452 | "output_type": "execute_result"
453 | }
454 | ],
455 | "source": [
456 | "train_transformed.shape"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": 29,
462 | "metadata": {
463 | "id": "JvhJL1PgOd14"
464 | },
465 | "outputs": [],
466 | "source": [
467 | "rf = RandomForestClassifier().fit(train_transformed, twenty_train.target)"
468 | ]
469 | },
470 | {
471 | "cell_type": "code",
472 | "execution_count": 30,
473 | "metadata": {
474 | "colab": {
475 | "base_uri": "https://localhost:8080/",
476 | "height": 71
477 | },
478 | "id": "wpU0y4CCO9ZX",
479 | "outputId": "4ffc4eef-b643-4b96-e286-48e8923c95a1"
480 | },
481 | "outputs": [
482 | {
483 | "name": "stderr",
484 | "output_type": "stream",
485 | "text": [
486 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: DeprecationWarning: Call to deprecated `wv` (Attribute will be removed in 4.0.0, use self instead).\n",
487 | " import sys\n"
488 | ]
489 | }
490 | ],
491 | "source": [
492 | "test_transformed = np.concatenate(\n",
493 | " [embed_text(t) for t in twenty_test.data]\n",
494 | ")"
495 | ]
496 | },
497 | {
498 | "cell_type": "code",
499 | "execution_count": 31,
500 | "metadata": {
501 | "colab": {
502 | "base_uri": "https://localhost:8080/",
503 | "height": 34
504 | },
505 | "id": "DzFbM0RHPc1h",
506 | "outputId": "463868e0-fdba-4971-d816-551a7032abcc"
507 | },
508 | "outputs": [
509 | {
510 | "data": {
511 | "text/plain": [
512 | "0.8608521970705726"
513 | ]
514 | },
515 | "execution_count": 31,
516 | "metadata": {
517 | "tags": []
518 | },
519 | "output_type": "execute_result"
520 | }
521 | ],
522 | "source": [
523 | "predicted = rf.predict(test_transformed)\n",
524 | "np.mean(predicted == twenty_test.target)"
525 | ]
526 | },
527 | {
528 | "cell_type": "markdown",
529 | "metadata": {
530 | "id": "PY8WpBQ-eWLh"
531 | },
532 | "source": [
533 | "# Keras model with an embedding layer"
534 | ]
535 | },
536 | {
537 | "cell_type": "code",
538 | "execution_count": 32,
539 | "metadata": {
540 | "id": "vrieNKV5e3Z8"
541 | },
542 | "outputs": [],
543 | "source": [
544 | "from tensorflow.keras import layers\n",
545 | "\n",
546 | "embedding = layers.Embedding(\n",
547 | " input_dim=5000, \n",
548 | " output_dim=50, \n",
549 | " input_length=500\n",
550 | ")"
551 | ]
552 | },
553 | {
554 | "cell_type": "code",
555 | "execution_count": 33,
556 | "metadata": {
557 | "id": "MInr1-EifFgQ"
558 | },
559 | "outputs": [],
560 | "source": [
561 | "from tensorflow.keras.preprocessing.text import Tokenizer\n",
562 | "\n",
563 | "tokenizer = Tokenizer(num_words=5000)\n",
564 | "tokenizer.fit_on_texts(twenty_train.data)"
565 | ]
566 | },
567 | {
568 | "cell_type": "code",
569 | "execution_count": 34,
570 | "metadata": {
571 | "id": "pEXw_q3yfae0"
572 | },
573 | "outputs": [],
574 | "source": [
575 | "X_train = tokenizer.texts_to_sequences(twenty_train.data)\n",
576 | "X_test = tokenizer.texts_to_sequences(twenty_test.data)"
577 | ]
578 | },
579 | {
580 | "cell_type": "code",
581 | "execution_count": 35,
582 | "metadata": {
583 | "id": "DIVyjcRAmX0m"
584 | },
585 | "outputs": [],
586 | "source": [
587 | "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
588 | "\n",
589 | "X_train = pad_sequences(X_train, padding='post', maxlen=500)\n",
590 | "X_test = pad_sequences(X_test, padding='post', maxlen=500)"
591 | ]
592 | },
593 | {
594 | "cell_type": "code",
595 | "execution_count": 36,
596 | "metadata": {
597 | "colab": {
598 | "base_uri": "https://localhost:8080/",
599 | "height": 289
600 | },
601 | "id": "jPBjYiageUbQ",
602 | "outputId": "d13e94a9-48f9-4b82-ae2d-bc8b57872e4f"
603 | },
604 | "outputs": [
605 | {
606 | "name": "stdout",
607 | "output_type": "stream",
608 | "text": [
609 | "Model: \"sequential\"\n",
610 | "_________________________________________________________________\n",
611 | "Layer (type) Output Shape Param # \n",
612 | "=================================================================\n",
613 | "embedding (Embedding) (None, 500, 50) 250000 \n",
614 | "_________________________________________________________________\n",
615 | "flatten (Flatten) (None, 25000) 0 \n",
616 | "_________________________________________________________________\n",
617 | "dense (Dense) (None, 10) 250010 \n",
618 | "_________________________________________________________________\n",
619 | "dense_1 (Dense) (None, 4) 44 \n",
620 | "=================================================================\n",
621 | "Total params: 500,054\n",
622 | "Trainable params: 500,054\n",
623 | "Non-trainable params: 0\n",
624 | "_________________________________________________________________\n"
625 | ]
626 | }
627 | ],
628 | "source": [
629 | "from tensorflow.keras.models import Sequential\n",
630 | "from tensorflow.keras.losses import SparseCategoricalCrossentropy\n",
631 | "from tensorflow.keras import regularizers\n",
632 | "\n",
633 | "model = Sequential()\n",
634 | "model.add(embedding)\n",
635 | "model.add(layers.Flatten())\n",
636 | "model.add(layers.Dense(\n",
637 | " 10,\n",
638 | " activation='relu',\n",
639 | " kernel_regularizer=regularizers.l1_l2(l1=1e-5, l2=1e-4)\n",
640 | "))\n",
641 | "model.add(layers.Dense(len(categories), activation='softmax'))\n",
642 | "model.compile(optimizer='adam',\n",
643 | " loss=SparseCategoricalCrossentropy(),\n",
644 | " metrics=['accuracy'])\n",
645 | "model.summary()"
646 | ]
647 | },
648 | {
649 | "cell_type": "code",
650 | "execution_count": 37,
651 | "metadata": {
652 | "colab": {
653 | "base_uri": "https://localhost:8080/",
654 | "height": 394
655 | },
656 | "id": "JVBg4RZTflSJ",
657 | "outputId": "ff0ac6a9-92a7-435a-8c53-2860261424e8"
658 | },
659 | "outputs": [
660 | {
661 | "name": "stdout",
662 | "output_type": "stream",
663 | "text": [
664 | "Epoch 1/10\n",
665 | "71/71 [==============================] - 3s 44ms/step - loss: 1.3417 - accuracy: 0.3345\n",
666 | "Epoch 2/10\n",
667 | "71/71 [==============================] - 4s 60ms/step - loss: 0.7579 - accuracy: 0.7461\n",
668 | "Epoch 3/10\n",
669 | "71/71 [==============================] - 4s 62ms/step - loss: 0.1984 - accuracy: 0.9730\n",
670 | "Epoch 4/10\n",
671 | "71/71 [==============================] - 5s 71ms/step - loss: 0.0950 - accuracy: 0.9969\n",
672 | "Epoch 5/10\n",
673 | "71/71 [==============================] - 5s 65ms/step - loss: 0.0707 - accuracy: 0.9987\n",
674 | "Epoch 6/10\n",
675 | "71/71 [==============================] - 5s 70ms/step - loss: 0.0583 - accuracy: 0.9996\n",
676 | "Epoch 7/10\n",
677 | "71/71 [==============================] - 5s 67ms/step - loss: 0.0514 - accuracy: 1.0000\n",
678 | "Epoch 8/10\n",
679 | "71/71 [==============================] - 5s 74ms/step - loss: 0.0463 - accuracy: 1.0000\n",
680 | "Epoch 9/10\n",
681 | "71/71 [==============================] - 5s 67ms/step - loss: 0.0422 - accuracy: 1.0000\n",
682 | "Epoch 10/10\n",
683 | "71/71 [==============================] - 5s 70ms/step - loss: 0.0388 - accuracy: 1.0000\n"
684 | ]
685 | },
686 | {
687 | "data": {
688 | "text/plain": [
689 | ""
690 | ]
691 | },
692 | "execution_count": 37,
693 | "metadata": {
694 | "tags": []
695 | },
696 | "output_type": "execute_result"
697 | }
698 | ],
699 | "source": [
700 | "model.fit(X_train, twenty_train.target, epochs=10)"
701 | ]
702 | },
703 | {
704 | "cell_type": "code",
705 | "execution_count": 38,
706 | "metadata": {
707 | "colab": {
708 | "base_uri": "https://localhost:8080/",
709 | "height": 34
710 | },
711 | "id": "XJ4HbaIYoWur",
712 | "outputId": "ef58ecce-57fd-47b0-dcb1-383a04c12318"
713 | },
714 | "outputs": [
715 | {
716 | "data": {
717 | "text/plain": [
718 | "0.8695073235685752"
719 | ]
720 | },
721 | "execution_count": 38,
722 | "metadata": {
723 | "tags": []
724 | },
725 | "output_type": "execute_result"
726 | }
727 | ],
728 | "source": [
729 | "predicted = model.predict(X_test).argmax(axis=1)\n",
730 | "np.mean(predicted == twenty_test.target)"
731 | ]
732 | },
733 | {
734 | "cell_type": "code",
735 | "execution_count": null,
736 | "metadata": {
737 | "id": "G1O2rHUKakdE"
738 | },
739 | "outputs": [],
740 | "source": []
741 | }
742 | ],
743 | "metadata": {
744 | "accelerator": "GPU",
745 | "colab": {
746 | "collapsed_sections": [],
747 | "name": "Classifying_newsgroups.ipynb",
748 | "provenance": [],
749 | "toc_visible": true
750 | },
751 | "kernelspec": {
752 | "display_name": "Python 3",
753 | "language": "python",
754 | "name": "python3"
755 | },
756 | "language_info": {
757 | "codemirror_mode": {
758 | "name": "ipython",
759 | "version": 3
760 | },
761 | "file_extension": ".py",
762 | "mimetype": "text/x-python",
763 | "name": "python",
764 | "nbconvert_exporter": "python",
765 | "pygments_lexer": "ipython3",
766 | "version": "3.6.11"
767 | }
768 | },
769 | "nbformat": 4,
770 | "nbformat_minor": 1
771 | }
772 |
--------------------------------------------------------------------------------
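The `model.summary()` output of the newsgroup classifier reports 500,054 trainable parameters. A short sketch recomputing them by hand shows where they come from. It assumes a vocabulary of 5,000 tokens and padded sequences of 500 tokens; neither value is stated in this part of the notebook, but both follow from the summary (250,000 embedding weights / 50 dimensions = 5,000 rows, and an Embedding output shape of (None, 500, 50)). The helper names are illustrative, not Keras functions.

```python
def embedding_params(vocab_size: int, embedding_dim: int) -> int:
    # An Embedding layer stores one vector per token and has no bias.
    return vocab_size * embedding_dim


def dense_params(n_inputs: int, n_units: int) -> int:
    # A Dense layer has an (n_inputs x n_units) kernel plus one bias per unit.
    return n_inputs * n_units + n_units


# Assumed sizes, inferred from the model.summary() output.
VOCAB_SIZE, SEQ_LEN, EMB_DIM = 5000, 500, 50
N_HIDDEN, N_CLASSES = 10, 4

# Flatten turns (500, 50) into a 25,000-dimensional vector; it has no weights.
flattened = SEQ_LEN * EMB_DIM

counts = {
    'embedding': embedding_params(VOCAB_SIZE, EMB_DIM),
    'flatten': 0,
    'dense': dense_params(flattened, N_HIDDEN),
    'dense_1': dense_params(N_HIDDEN, N_CLASSES),
}
total = sum(counts.values())

for name, n in counts.items():
    print(f'{name:10s}{n:>10,}')
print(f'{"total":10s}{total:>10,}')
```

About half of all weights sit in the first Dense layer because Flatten feeds it a 25,000-dimensional vector. Together with a training accuracy of 1.0000 against a test accuracy of about 0.87, this suggests the model is overfitting. Pooling the embeddings (for example with `GlobalAveragePooling1D`) instead of flattening would shrink that layer to 50 × 10 + 10 = 510 weights.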
/chapter11/Serving a model for live-decisioning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
111 | "source": [
112 | "!pip install mlflow"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "# this is adapted from https://github.com/mlflow/mlflow/blob/master/examples/sklearn_elasticnet_wine/train.ipynb"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 9,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "import pandas as pd\n",
131 | "import numpy as np\n",
132 | "from sklearn.model_selection import train_test_split\n",
133 | "\n",
134 | "csv_url =\\\n",
135 | " 'http://archive.ics.uci.edu/ml/machine-' \\\n",
136 | " 'learning-databases/wine-quality/winequality-red.csv'\n",
137 | "data = pd.read_csv(csv_url, sep=';')"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 10,
143 | "metadata": {},
144 | "outputs": [
145 | {
146 | "data": {
260 | "text/plain": [
261 | " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
262 | "0 7.4 0.70 0.00 1.9 0.076 \n",
263 | "1 7.8 0.88 0.00 2.6 0.098 \n",
264 | "2 7.8 0.76 0.04 2.3 0.092 \n",
265 | "3 11.2 0.28 0.56 1.9 0.075 \n",
266 | "4 7.4 0.70 0.00 1.9 0.076 \n",
267 | "\n",
268 | " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
269 | "0 11.0 34.0 0.9978 3.51 0.56 \n",
270 | "1 25.0 67.0 0.9968 3.20 0.68 \n",
271 | "2 15.0 54.0 0.9970 3.26 0.65 \n",
272 | "3 17.0 60.0 0.9980 3.16 0.58 \n",
273 | "4 11.0 34.0 0.9978 3.51 0.56 \n",
274 | "\n",
275 | " alcohol quality \n",
276 | "0 9.4 5 \n",
277 | "1 9.8 5 \n",
278 | "2 9.8 5 \n",
279 | "3 9.8 6 \n",
280 | "4 9.4 5 "
281 | ]
282 | },
283 | "execution_count": 10,
284 | "metadata": {},
285 | "output_type": "execute_result"
286 | }
287 | ],
288 | "source": [
289 | "data.head()"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": 11,
295 | "metadata": {},
296 | "outputs": [],
297 | "source": [
298 | "# Split the data into training and test sets. (0.75, 0.25) split.\n",
299 | "# The predicted column is \"quality\" which is a scalar from [3, 9]\n",
300 | "train_x, test_x, train_y, test_y = train_test_split(\n",
301 | " data.drop(['quality'], axis=1),\n",
302 | " data['quality']\n",
303 | ")"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": 12,
309 | "metadata": {},
310 | "outputs": [],
311 | "source": [
312 | "from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
313 | "\n",
314 | "def eval_metrics(actual, pred):\n",
315 | "    rmse = np.sqrt(mean_squared_error(actual, pred))\n",
316 | "    mae = mean_absolute_error(actual, pred)\n",
317 | "    r2 = r2_score(actual, pred)\n",
318 | "    return rmse, mae, r2"
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": 13,
324 | "metadata": {},
325 | "outputs": [],
326 | "source": [
327 | "import mlflow\n",
328 | "\n",
329 | "mlflow.set_tracking_uri('http://0.0.0.0:5000') # set to your server URI\n",
330 | "mlflow.set_experiment('/wine')"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": 17,
336 | "metadata": {},
337 | "outputs": [],
338 | "source": [
339 | "from sklearn.linear_model import ElasticNet\n",
340 | "import mlflow.sklearn\n",
341 | "\n",
342 | "np.random.seed(40)\n",
343 | "\n",
344 | "def train(alpha=0.5, l1_ratio=0.5):\n",
345 | "    with mlflow.start_run():\n",
346 | "        lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)\n",
347 | "        lr.fit(train_x, train_y)\n",
348 | "        predicted = lr.predict(test_x)\n",
349 | "        rmse, mae, r2 = eval_metrics(test_y, predicted)\n",
350 | "\n",
351 | "        model_name = lr.__class__.__name__\n",
352 | "        print('{} (alpha={}, l1_ratio={}):'.format(\n",
353 | "            model_name, alpha, l1_ratio\n",
354 | "        ))\n",
355 | "        print(' RMSE: %s' % rmse)\n",
356 | "        print(' MAE: %s' % mae)\n",
357 | "        print(' R2: %s' % r2)\n",
358 | "\n",
359 | "        mlflow.log_params(lr.get_params())\n",
360 | "        mlflow.log_metric('rmse', rmse)\n",
361 | "        mlflow.log_metric('r2', r2)\n",
362 | "        mlflow.log_metric('mae', mae)\n",
363 | "        mlflow.sklearn.log_model(lr, model_name)"
364 | ]
365 | },
366 | {
367 | "cell_type": "code",
368 | "execution_count": 18,
369 | "metadata": {},
370 | "outputs": [
371 | {
372 | "name": "stdout",
373 | "output_type": "stream",
374 | "text": [
375 | "ElasticNet (alpha=0.5, l1_ratio=0.5):\n",
376 | " RMSE: 0.7325693777577805\n",
377 | " MAE: 0.5895721434715478\n",
378 | " R2: 0.12163690293641838\n"
379 | ]
380 | }
381 | ],
382 | "source": [
383 | "train(0.5, 0.5)"
384 | ]
385 | },
386 | {
387 | "cell_type": "code",
388 | "execution_count": 19,
389 | "metadata": {},
390 | "outputs": [
391 | {
392 | "name": "stdout",
393 | "output_type": "stream",
394 | "text": [
395 | "ElasticNet (alpha=0.1, l1_ratio=0.5):\n",
396 | " RMSE: 0.6832521710295818\n",
397 | " MAE: 0.5350826216023779\n",
398 | " R2: 0.23592040719074103\n"
399 | ]
400 | }
401 | ],
402 | "source": [
403 | "train(0.1, 0.5)"
404 | ]
405 | },
406 | {
407 | "cell_type": "code",
408 | "execution_count": 20,
409 | "metadata": {},
410 | "outputs": [
411 | {
412 | "name": "stdout",
413 | "output_type": "stream",
414 | "text": [
415 | "ElasticNet (alpha=0.8, l1_ratio=0.5):\n",
416 | " RMSE: 0.7713038517785624\n",
417 | " MAE: 0.6344212065633348\n",
418 | " R2: 0.026294640912563283\n"
419 | ]
420 | }
421 | ],
422 | "source": [
423 | "train(0.8, 0.5)"
424 | ]
425 | },
426 | {
427 | "cell_type": "code",
428 | "execution_count": 21,
429 | "metadata": {},
430 | "outputs": [
431 | {
432 | "name": "stdout",
433 | "output_type": "stream",
434 | "text": [
435 | "ElasticNet (alpha=0.1, l1_ratio=0.2):\n",
436 | " RMSE: 0.6740753299699419\n",
437 | " MAE: 0.5276949437873688\n",
438 | " R2: 0.25630745861273185\n"
439 | ]
440 | }
441 | ],
442 | "source": [
443 | "train(0.1, 0.2)"
444 | ]
445 | },
446 | {
447 | "cell_type": "code",
448 | "execution_count": 22,
449 | "metadata": {},
450 | "outputs": [
451 | {
452 | "name": "stdout",
453 | "output_type": "stream",
454 | "text": [
455 | "ElasticNet (alpha=0.1, l1_ratio=0.3):\n",
456 | " RMSE: 0.6781545799635063\n",
457 | " MAE: 0.5308991080094628\n",
458 | " R2: 0.2472791287278865\n"
459 | ]
460 | }
461 | ],
462 | "source": [
463 | "train(0.1, 0.3)"
464 | ]
465 | },
466 | {
467 | "cell_type": "code",
468 | "execution_count": 23,
469 | "metadata": {},
470 | "outputs": [
471 | {
472 | "name": "stdout",
473 | "output_type": "stream",
474 | "text": [
475 | "ElasticNet (alpha=0.2, l1_ratio=0.2):\n",
476 | " RMSE: 0.6844856006568806\n",
477 | " MAE: 0.5375117838920673\n",
478 | " R2: 0.2331592331820277\n"
479 | ]
480 | }
481 | ],
482 | "source": [
483 | "train(0.2, 0.2)"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": 24,
489 | "metadata": {},
490 | "outputs": [
491 | {
492 | "name": "stdout",
493 | "output_type": "stream",
494 | "text": [
495 | "ElasticNet (alpha=0.1, l1_ratio=0.1):\n",
496 | " RMSE: 0.6690250543541869\n",
497 | " MAE: 0.5236546308642179\n",
498 | " R2: 0.2674094302489908\n"
499 | ]
500 | }
501 | ],
502 | "source": [
503 | "train(0.1, 0.1)"
504 | ]
505 | },
506 | {
507 | "cell_type": "code",
508 | "execution_count": 25,
509 | "metadata": {},
510 | "outputs": [
511 | {
512 | "name": "stdout",
513 | "output_type": "stream",
514 | "text": [
515 | "ElasticNet (alpha=0.05, l1_ratio=0.05):\n",
516 | " RMSE: 0.6586566565045078\n",
517 | " MAE: 0.5152309297379374\n",
518 | " R2: 0.28994051940308996\n"
519 | ]
520 | }
521 | ],
522 | "source": [
523 | "train(0.05, 0.05)"
524 | ]
525 | },
526 | {
527 | "cell_type": "code",
528 | "execution_count": null,
529 | "metadata": {},
530 | "outputs": [],
531 | "source": []
532 | }
533 | ],
534 | "metadata": {
535 | "kernelspec": {
536 | "display_name": "Python 3",
537 | "language": "python",
538 | "name": "python3"
539 | },
540 | "language_info": {
541 | "codemirror_mode": {
542 | "name": "ipython",
543 | "version": 3
544 | },
545 | "file_extension": ".py",
546 | "mimetype": "text/x-python",
547 | "name": "python",
548 | "nbconvert_exporter": "python",
549 | "pygments_lexer": "ipython3",
550 | "version": "3.6.11"
551 | }
552 | },
553 | "nbformat": 4,
554 | "nbformat_minor": 2
555 | }
556 |
--------------------------------------------------------------------------------
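The MLflow notebook sweeps the `alpha` and `l1_ratio` parameters of scikit-learn's `ElasticNet`. According to the scikit-learn documentation, the estimator minimizes `1/(2n) * ||y - Xw||^2 + alpha * l1_ratio * ||w||_1 + 0.5 * alpha * (1 - l1_ratio) * ||w||^2`. The sketch below only evaluates that objective for fixed coefficients and fits nothing. `elastic_net_objective` is a hypothetical helper, and treating the intercept `b` as unpenalized is an assumption that matches scikit-learn fitting the intercept separately.

```python
from typing import Sequence


def elastic_net_objective(
    X: Sequence[Sequence[float]],
    y: Sequence[float],
    w: Sequence[float],
    b: float,
    alpha: float,
    l1_ratio: float,
) -> float:
    """Evaluate the ElasticNet loss for fixed coefficients w and intercept b.

    loss = 1/(2n) * ||y - Xw - b||_2^2
           + alpha * l1_ratio * ||w||_1
           + 0.5 * alpha * (1 - l1_ratio) * ||w||_2^2
    The intercept b is not penalized.
    """
    n = len(y)
    # Squared-error data term, scaled by 1/(2n).
    residuals = [
        yi - (sum(xij * wj for xij, wj in zip(xi, w)) + b)
        for xi, yi in zip(X, y)
    ]
    data_term = sum(r * r for r in residuals) / (2 * n)
    # L1 (lasso) and squared L2 (ridge) penalties on the coefficients only.
    l1_penalty = sum(abs(wj) for wj in w)
    l2_penalty = sum(wj * wj for wj in w)
    return (
        data_term
        + alpha * l1_ratio * l1_penalty
        + 0.5 * alpha * (1 - l1_ratio) * l2_penalty
    )


# A perfect fit still pays a penalty that grows with alpha.
X = [[1.0], [2.0]]
y = [1.0, 2.0]
w = [1.0]
print(elastic_net_objective(X, y, w, 0.0, alpha=0.5, l1_ratio=0.5))  # 0.375
print(elastic_net_objective(X, y, w, 0.0, alpha=0.5, l1_ratio=1.0))  # 0.5
print(elastic_net_objective(X, y, w, 0.0, alpha=0.0, l1_ratio=0.5))  # 0.0
```

With `l1_ratio=1` the penalty is pure lasso, and with `l1_ratio=0` it is pure ridge. In the logged runs, weaker penalties generally gave better test metrics (R2 rose from 0.026 at `alpha=0.8, l1_ratio=0.5` to 0.29 at `alpha=0.05, l1_ratio=0.05`). The sweep stopped at 0.05, so smaller values remain untested.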
/chapter11/visualizing_model_results.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualizing the results of a classification model in streamlit.
3 | """
4 | import numpy as np
5 | import pandas as pd
6 | import altair as alt
7 | import streamlit as st
8 |
9 | from sklearn.datasets import (
10 | load_iris,
11 | load_wine,
12 | fetch_covtype
13 | )
14 | from sklearn.model_selection import train_test_split
15 | from sklearn.ensemble import (
16 | RandomForestClassifier,
17 | ExtraTreesClassifier,
18 | )
19 | from sklearn.tree import DecisionTreeClassifier
20 | from sklearn.metrics import roc_auc_score
21 | from sklearn.metrics import classification_report
22 |
23 |
24 | dataset_lookup = {
25 | 'Iris': load_iris,
26 | 'Wine': load_wine,
27 | 'Covertype': fetch_covtype,
28 | }
29 |
30 |
31 | @st.cache
32 | def load_data(name):
33 |     dataset = dataset_lookup[name]()
34 |     X_train, X_test, y_train, y_test = train_test_split(
35 |         dataset.data, dataset.target, test_size=0.33, random_state=42
36 |     )
37 |     feature_names = getattr(
38 |         dataset, 'feature_names',
39 |         [str(i) for i in range(X_train.shape[1])]
40 |     )
41 |     target_names = getattr(
42 |         dataset, 'target_names',
43 |         [str(i) for i in np.unique(dataset.target)]
44 |     )
45 |     return (
46 |         X_train, X_test, y_train, y_test,
47 |         target_names, feature_names
48 |     )
49 |
50 |
51 | @st.cache
52 | def train_model(dataset_name, model_name, n_estimators, max_depth):
53 |     model = [m for m in models if m.__class__.__name__ == model_name][0]
54 |     with st.spinner('Building a {} model for {} ...'.format(
55 |         model_name, dataset_name
56 |     )):
57 |         return model.fit(X_train, y_train)
58 |
59 |
60 | # sidebar:
61 | st.sidebar.title('Model and dataset selection')
62 | dataset_name = st.sidebar.selectbox(
63 | 'Dataset',
64 | list(dataset_lookup.keys())
65 | )
66 | (X_train, X_test, y_train, y_test,
67 | target_names, feature_names) = load_data(dataset_name)
68 |
69 | n_estimators = st.sidebar.slider(
70 | 'n_estimators',
71 | 1, 100, 25
72 | )
73 | max_depth = st.sidebar.slider(
74 | 'max_depth',
75 | 1, 150, 10
76 | )
77 | models = [
78 | DecisionTreeClassifier(max_depth=max_depth),
79 | RandomForestClassifier(
80 | n_estimators=n_estimators,
81 | max_depth=max_depth
82 | ),
83 | ExtraTreesClassifier(
84 | n_estimators=n_estimators,
85 | max_depth=max_depth
86 | ),
87 | ]
88 | model_name = st.sidebar.selectbox(
89 | 'Model',
90 | [m.__class__.__name__ for m in models]
91 | )
92 | model = train_model(dataset_name, model_name, n_estimators, max_depth)
93 |
94 |
95 | # main content:
96 | st.title('{model} on {dataset}'.format(
97 | model=model_name,
98 | dataset=dataset_name
99 | ))
100 |
101 | predictions = model.predict(X_test)
102 | probs = model.predict_proba(X_test)
103 | st.subheader('Model performance in test')
104 | st.write('AUC: {:.2f}'.format(
105 |     roc_auc_score(
106 |         y_test, probs if len(target_names) > 2 else probs[:, 1],
107 |         multi_class='ovo' if len(target_names) > 2 else 'raise',
108 |         average='macro' if len(target_names) > 2 else None
109 |     )
110 | ))
111 | st.write(
112 | pd.DataFrame(
113 | classification_report(
114 | y_test, predictions,
115 | target_names=target_names,
116 | output_dict=True
117 | )
118 | )
119 | )
120 | test_df = pd.DataFrame(
121 | data=np.concatenate([
122 | X_test,
123 | y_test.reshape(-1, 1),
124 | predictions.reshape(-1, 1)
125 | ], axis=1),
126 | columns=feature_names + [
127 | 'target', 'predicted'
128 | ]
129 | )
130 | target_map = dict(zip(np.unique(np.concatenate([y_train, y_test])), target_names))
131 | test_df.target = test_df.target.map(target_map)
132 | test_df.predicted = test_df.predicted.map(target_map)
133 | confusion_matrix = pd.crosstab(
134 | test_df['target'],
135 | test_df['predicted'],
136 | rownames=['Actual'],
137 | colnames=['Predicted']
138 | )
139 | st.subheader('Confusion Matrix')
140 | st.write(confusion_matrix)
141 |
142 |
143 | def highlight_error(s):
144 |     if s.predicted == s.target:
145 |         return [''] * len(s)
146 |     return ['background-color: red'] * len(s)
147 |
148 |
149 | if st.checkbox('Show test data'):
150 |     st.subheader('Test data')
151 |     st.write(test_df.style.apply(highlight_error, axis=1))
152 |
153 |
154 | if st.checkbox('Show test distributions'):
155 |     st.subheader('Distributions')
156 |     row_features = feature_names[:len(feature_names)//2]
157 |     col_features = feature_names[len(row_features):]
158 |     test_df_with_error = test_df.copy()
159 |     test_df_with_error['error'] = test_df.predicted != test_df.target
160 |     chart = alt.Chart(test_df_with_error).mark_circle().encode(
161 |         alt.X(alt.repeat("column"), type='quantitative'),
162 |         alt.Y(alt.repeat("row"), type='quantitative'),
163 |         color='error:N'
164 |     ).properties(
165 |         width=250,
166 |         height=250
167 |     ).repeat(
168 |         row=row_features,
169 |         column=col_features
170 |     ).interactive()
171 |     st.write(chart)
--------------------------------------------------------------------------------
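For the multi-class datasets, the Streamlit app reports `roc_auc_score(..., multi_class='ovo', average='macro')`. scikit-learn documents this as a one-vs-one scheme: an AUC is computed for every pair of classes and the pairwise scores are averaged uniformly. The pure-Python sketch below mirrors that definition to make the reported number easier to interpret. `binary_auc` and `ovo_macro_auc` are illustrative names, not scikit-learn functions, and the pairwise loop is quadratic, so it is only suited to small examples.

```python
from itertools import combinations
from typing import Hashable, Sequence


def binary_auc(is_positive: Sequence[bool], scores: Sequence[float]) -> float:
    """AUC as the probability that a random positive outscores a random negative.

    Ties count as half a win (the Mann-Whitney formulation).
    """
    pos = [s for flag, s in zip(is_positive, scores) if flag]
    neg = [s for flag, s in zip(is_positive, scores) if not flag]
    wins = sum(
        1.0 if p > q else 0.5 if p == q else 0.0
        for p in pos
        for q in neg
    )
    return wins / (len(pos) * len(neg))


def ovo_macro_auc(
    y_true: Sequence[Hashable],
    probs: Sequence[Sequence[float]],
    classes: Sequence[Hashable],
) -> float:
    """One-vs-one AUC, averaged uniformly over all pairs of classes.

    classes[k] is the label whose probability sits in column k of probs
    (the role played by model.classes_ in scikit-learn).
    """
    pair_scores = []
    for a, b in combinations(range(len(classes)), 2):
        # Only samples whose true label is one of the two classes take part.
        idx = [i for i, y in enumerate(y_true) if y in (classes[a], classes[b])]
        # Score the pair in both directions and average the two AUCs.
        auc_a = binary_auc([y_true[i] == classes[a] for i in idx],
                           [probs[i][a] for i in idx])
        auc_b = binary_auc([y_true[i] == classes[b] for i in idx],
                           [probs[i][b] for i in idx])
        pair_scores.append((auc_a + auc_b) / 2)
    return sum(pair_scores) / len(pair_scores)


# A perfect three-class classifier scores 1.0 on every pair.
print(ovo_macro_auc([0, 1, 2], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0, 1, 2]))
```

A macro one-vs-one AUC of 1.0 therefore means the predicted probabilities separate every pair of classes perfectly. It says nothing about calibration, nor about the hard `predict` labels summarized in the classification report and confusion matrix.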