├── model.png ├── 02 model.png ├── movielens.zip ├── data ├── mldataset │ ├── movies.dat │ └── README ├── gender.dat ├── training_match.json └── users_model_customers.json ├── .gitattributes ├── .gitignore ├── __pycache__ ├── metrics.cpython-37.pyc └── preprocessing_data.cpython-37.pyc ├── README.md ├── metrics.py ├── preprocessing_data.py └── .ipynb_checkpoints ├── PySparkRecSys-checkpoint.ipynb ├── RecSysKeras-checkpoint.ipynb └── RecSysTripletLoss-checkpoint.ipynb /model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oscar-defelice/DeepRecSys/master/model.png -------------------------------------------------------------------------------- /02 model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oscar-defelice/DeepRecSys/master/02 model.png -------------------------------------------------------------------------------- /movielens.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oscar-defelice/DeepRecSys/master/movielens.zip -------------------------------------------------------------------------------- /data/mldataset/movies.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oscar-defelice/DeepRecSys/master/data/mldataset/movies.dat -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | data/* filter=lfs diff=lfs merge=lfs -text 2 | data/ratings.dat filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # OS files 2 | .DS_Store 3 | ._.DS_Store 4 | **/.DS_Store 5 | **/._.DS_Store 6 | 7 | 
./data/ratings.dat -------------------------------------------------------------------------------- /__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oscar-defelice/DeepRecSys/master/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/preprocessing_data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oscar-defelice/DeepRecSys/master/__pycache__/preprocessing_data.cpython-37.pyc -------------------------------------------------------------------------------- /data/gender.dat: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8bf68902af27d170705699a30defeca28ae4fce705a6b0f2b3027d2aa2f46ffa 3 | size 2098595 4 | -------------------------------------------------------------------------------- /data/training_match.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:65161de5649844119660b9b9f3dc08fdeae71639466e12cd373e2b87c5ece71e 3 | size 8688884 4 | -------------------------------------------------------------------------------- /data/users_model_customers.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e75ee7cccaade417409ad82e3a424decb727df98c6e0bf3b784c75980133a0db 3 | size 5848848 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Recommender System 2 | 3 | Code to build a hybrid recommender system based on Keras. 
# ---------------------------------------------------------------------------
# /README.md (continued)
# ---------------------------------------------------------------------------
#
# The aim of this repository is to collect the code to leverage a deep
# learning framework to create a hybrid recommender system, *i.e.* a model
# exploiting both content and collaborative-filter data.
#
# The idea is to tackle issues in two different steps: first collaborative
# filtering and content based model separately, then a combination of the
# two, to get better results.
#
# ## Medium post
#
# A somewhat detailed description of the code can be found in
# [this medium post](https://medium.com/deep-recommender-system/a-deep-recommender-system-e2b765d27350).

import itertools
import os
import zipfile

import numpy as np

import scipy.sparse as sp


# ---------------------------------------------------------------------------
# /metrics.py
# ---------------------------------------------------------------------------

def _embedding_matrices(model):
    """Return the ``(user, item)`` embedding weight matrices of *model*.

    The model is expected to expose ``get_layer(name).get_weights()`` for
    layers named ``'user_emb'`` and ``'item_emb'``; the first weight array
    of each layer is used.
    """
    user_matrix = model.get_layer('user_emb').get_weights()[0]
    item_matrix = model.get_layer('item_emb').get_weights()[0]
    return user_matrix, item_matrix


def predict(model, uid, pids):
    """
    Score items for one user as the dot product of their embeddings.
    Arguments:
    - model exposing 'user_emb' and 'item_emb' layers
    - int uid, row of the user embedding matrix
    - integer array pids, rows of the item embedding matrix
    Returns:
    - array of scores, one per entry of pids
    """

    user_matrix, item_matrix = _embedding_matrices(model)

    return np.dot(user_matrix[uid], item_matrix[pids].T)


def precision_at_k(model, ground_truth, k, user_features=None, item_features=None):
    """
    Measure precision at k for model and ground truth.
    Arguments:
    - model: either an object exposing 'user_emb'/'item_emb' layers through
      get_layer (scored like predict() in this module), or an object whose
      predict(uids, pids, user_features=, item_features=, num_threads=)
      returns one score per (uid, pid) pair
    - sparse matrix ground_truth (no_users, no_items)
    - int k
    - user_features, item_features: forwarded to model.predict only; they
      are ignored for embedding-layer models
    Returns:
    - float precision@k averaged over users with at least one positive,
      or 0.0 when no user has any positive
    """

    ground_truth = ground_truth.tocsr()

    _, no_items = ground_truth.shape

    pid_array = np.arange(no_items, dtype=np.int32)

    # Embedding-layer models do not accept the feature keyword arguments in
    # their own predict(), so they are scored from their weights instead.
    # The weights are fetched once rather than once per user.
    uses_embeddings = hasattr(model, 'get_layer')
    if uses_embeddings:
        user_matrix, item_matrix = _embedding_matrices(model)
        candidate_items_t = item_matrix[pid_array].T

    precisions = []

    for user_id, row in enumerate(ground_truth):
        true_pids = set(row.indices[row.data == 1])

        # Users without positives do not contribute to the average.
        if not true_pids:
            continue

        if uses_embeddings:
            predictions = np.dot(user_matrix[user_id], candidate_items_t)
        else:
            uid_array = np.full(no_items, user_id, dtype=np.int32)
            predictions = model.predict(uid_array, pid_array,
                                        user_features=user_features,
                                        item_features=item_features,
                                        num_threads=4)

        top_k = set(np.argsort(-predictions)[:k])
        precisions.append(len(top_k & true_pids) / float(k))

    if not precisions:
        return 0.0

    return sum(precisions) / len(precisions)


def full_auc(model, ground_truth):
    """
    Measure AUC for model and ground truth on all items.
    Arguments:
    - model exposing 'user_emb' and 'item_emb' layers
    - sparse matrix ground_truth (no_users, no_items)
    Returns:
    - float AUC averaged over users that have both positive and negative
      items, or 0.0 when there is no such user
    """
    # Imported here so the rest of the module stays importable without
    # scikit-learn installed.
    from sklearn.metrics import roc_auc_score

    ground_truth = ground_truth.tocsr()

    _, no_items = ground_truth.shape

    pid_array = np.arange(no_items, dtype=np.int32)

    # Fetch the weights once instead of once per user.
    user_matrix, item_matrix = _embedding_matrices(model)
    candidate_items_t = item_matrix[pid_array].T

    scores = []

    for user_id, row in enumerate(ground_truth):
        true_pids = row.indices[row.data == 1]

        grnd = np.zeros(no_items, dtype=np.int32)
        grnd[true_pids] = 1

        # ROC AUC is undefined when only one class is present, so users with
        # no positives or only positives are skipped.
        positives = int(grnd.sum())
        if positives == 0 or positives == no_items:
            continue

        predictions = np.dot(user_matrix[user_id], candidate_items_t)
        scores.append(roc_auc_score(grnd, predictions))

    if not scores:
        return 0.0

    return sum(scores) / len(scores)


# ---------------------------------------------------------------------------
# /preprocessing_data.py
# ---------------------------------------------------------------------------

_MOVIELENS_URL = 'https://files.grouplens.org/datasets/movielens/ml-100k.zip'
_DOWNLOAD_CHUNK_SIZE = 64 * 1024
_DOWNLOAD_TIMEOUT_SECONDS = 60


def _get_movielens_path():
    """
    Get path to the movielens dataset file.
    """

    return os.path.join(os.path.dirname(os.path.abspath(__file__)),
                        'movielens.zip')


def _download_movielens(dest_path):
    """
    Download the dataset to dest_path.

    The archive is streamed to a temporary sibling file and moved into place
    only after the whole body has been written, so a failed or interrupted
    download never leaves a corrupt file at dest_path. HTTP errors raise
    instead of being saved as if they were the archive.
    """
    # Imported here so the parsing helpers stay usable without requests.
    import requests

    print('Downloading MovieLens data')

    tmp_path = dest_path + '.part'

    try:
        with requests.get(_MOVIELENS_URL, stream=True,
                          timeout=_DOWNLOAD_TIMEOUT_SECONDS) as req:
            req.raise_for_status()

            with open(tmp_path, 'wb') as fd:
                # An explicit chunk size avoids the 1-byte default.
                for chunk in req.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
                    fd.write(chunk)

        os.replace(tmp_path, dest_path)
    finally:
        # Only present here if the download did not complete.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _read_movielens_members(*names):
    """Return the raw bytes of the named archive members, downloading first if needed."""
    path = _get_movielens_path()

    if not os.path.isfile(path):
        _download_movielens(path)

    with zipfile.ZipFile(path) as datafile:
        return [datafile.read(name) for name in names]


def _get_raw_movielens_data():
    """
    Return the raw lines of the train and test files.
    """

    train_bytes, test_bytes = _read_movielens_members('ml-100k/ua.base',
                                                      'ml-100k/ua.test')

    return (train_bytes.decode().split('\n'),
            test_bytes.decode().split('\n'))


def _parse(data):
    """
    Parse movielens dataset lines.

    Yields (uid, iid, rating, timestamp) integer tuples from tab-separated
    lines; empty and whitespace-only lines are skipped.
    """

    for line in data:

        if not line.strip():
            continue

        uid, iid, rating, timestamp = (int(x) for x in line.split('\t'))

        yield uid, iid, rating, timestamp


def _build_interaction_matrix(rows, cols, data, min_rating=4.0):
    """
    Build a binary (rows, cols) interaction matrix in COO format.

    data is an iterable of (uid, iid, rating, timestamp) tuples. A cell is
    set to 1 when the rating is at least min_rating; everything else stays 0.
    """

    mat = sp.lil_matrix((rows, cols), dtype=np.int32)

    for uid, iid, rating, _ in data:
        # Let's assume only really good things are positives
        if rating >= min_rating:
            mat[uid, iid] = 1

    return mat.tocoo()


def _get_movie_raw_metadata():
    """
    Get raw lines of the genre file.
    """

    (item_bytes,) = _read_movielens_members('ml-100k/u.item')

    # Bytes that are not valid UTF-8 are dropped rather than raising; only
    # the numeric fields of each line are consumed in this module.
    return item_bytes.decode(errors='ignore').split('\n')


def _build_item_feature_matrix(lines, use_item_ids):
    """Build the item feature matrix from pipe-separated item lines.

    Field 0 of each line is the item id and fields 5 onwards are 0/1 genre
    flags. Genre g is stored in column g. When use_item_ids is true, item i
    additionally gets a private column at ``n_genre_columns + i`` so that
    per-item features can never alias a genre column.
    """
    features = {}
    n_genre_columns = 0

    for line in lines:

        if not line.strip():
            continue

        splt = line.split('|')
        item_id = int(splt[0])
        flags = splt[5:]

        n_genre_columns = max(n_genre_columns, len(flags))
        features[item_id] = [idx for idx, val in enumerate(flags)
                             if int(val) > 0]

    # Sized from the largest id, not the number of items, so gaps in the id
    # sequence cannot cause out-of-range writes.
    n_rows = max(features, default=0) + 1
    n_cols = n_genre_columns + (n_rows if use_item_ids else 0)

    mat = sp.lil_matrix((n_rows, n_cols), dtype=np.int32)

    for item_id, genre_ids in features.items():
        for genre_id in genre_ids:
            mat[item_id, genre_id] = 1

        if use_item_ids:
            mat[item_id, n_genre_columns + item_id] = 1

    return mat


def get_movielens_item_metadata(use_item_ids):
    """
    Build a matrix of genre features (no_items, no_features).
    If use_item_ids is True, per-item features will also be used; they
    occupy their own columns after the genre columns.
    """

    return _build_item_feature_matrix(_get_movie_raw_metadata(), use_item_ids)


def _one_hot(indices, depth):
    """Return float one-hot rows of width depth for integer indices.

    Gives the same values as ``np.identity(depth)[indices]`` without
    allocating the full depth x depth identity matrix.
    """
    idx = np.asarray(indices)
    if idx.size == 0:
        # An empty Python list becomes a float array; it is still a valid
        # (empty) index.
        idx = idx.astype(np.intp)

    flat = idx.reshape(-1)
    encoded = np.zeros((flat.size, depth))
    encoded[np.arange(flat.size), flat] = 1.0

    return encoded.reshape(idx.shape + (depth,))


def get_dense_triplets(uids, pids, nids, num_users, num_items):
    """
    Return one-hot encodings of (users, positive items, negative items).

    uids, pids and nids are integer indices; the results have one row per
    index, of width num_users, num_items and num_items respectively.
    """

    return (_one_hot(uids, num_users),
            _one_hot(pids, num_items),
            _one_hot(nids, num_items))


def get_triplets(mat, random_state=None):
    """
    Return (user ids, positive item ids, sampled negative item ids).

    Users and positives are the row/col arrays of the COO matrix mat.
    Negatives are drawn uniformly from all column indices, so a sampled
    "negative" may occasionally be a known positive. Pass random_state
    (a seed) for reproducible sampling; by default the global numpy random
    state is used.
    """

    if random_state is None:
        rng = np.random
    else:
        rng = np.random.RandomState(random_state)

    return mat.row, mat.col, rng.randint(mat.shape[1], size=len(mat.row))


def get_movielens_data():
    """
    Return (train_interactions, test_interactions).
    """

    train_lines, test_lines = _get_raw_movielens_data()

    # Parse each file once; the tuples are reused for sizing and building.
    train_ratings = list(_parse(train_lines))
    test_ratings = list(_parse(test_lines))

    uids = set()
    iids = set()

    for uid, iid, _, _ in itertools.chain(train_ratings, test_ratings):
        uids.add(uid)
        iids.add(iid)

    rows = max(uids) + 1
    cols = max(iids) + 1

    return (_build_interaction_matrix(rows, cols, train_ratings),
            _build_interaction_matrix(rows, cols, test_ratings))


# ---------------------------------------------------------------------------
# /data/mldataset/README (beginning)
# ---------------------------------------------------------------------------
#
# SUMMARY
# ================================================================================
#
# These files contain 1,000,209 anonymous ratings of approximately 3,900 movies
# made by 6,040 MovieLens users who joined MovieLens in 2000.
#
# USAGE LICENSE
# ================================================================================
#
# Neither the University of Minnesota nor any of the researchers
# involved can guarantee the correctness of the data, its suitability
# for any particular purpose, or the validity of results based on the
# use of the data set. The data set may be used for any research
# purposes under the following conditions:
#
#      * The user may not state or imply any endorsement from the
#        University of Minnesota or the GroupLens Research Group.
#
#      * The user must acknowledge the use of the data set in
#        publications resulting from the use of the data set
#        (see below for citation information).
#
#      * The user may not redistribute the data without separate
#        permission.
#
#      * The user may not use this information for any commercial or
#        revenue-bearing purposes without first obtaining permission
#        from a faculty member of the GroupLens Research Project at the
#        University of Minnesota.
30 | 31 | If you have any further questions or comments, please contact GroupLens 32 | <grouplens-info@cs.umn.edu>. 33 | 34 | CITATION 35 | ================================================================================ 36 | 37 | To acknowledge use of the dataset in publications, please cite the following 38 | paper: 39 | 40 | F. Maxwell Harper and Joseph A. Konstan. 2015. The MovieLens Datasets: History 41 | and Context. ACM Transactions on Interactive Intelligent Systems (TiiS) 5, 4, 42 | Article 19 (December 2015), 19 pages. DOI=http://dx.doi.org/10.1145/2827872 43 | 44 | 45 | ACKNOWLEDGEMENTS 46 | ================================================================================ 47 | 48 | Thanks to Shyong Lam and Jon Herlocker for cleaning up and generating the data 49 | set. 50 | 51 | FURTHER INFORMATION ABOUT THE GROUPLENS RESEARCH PROJECT 52 | ================================================================================ 53 | 54 | The GroupLens Research Project is a research group in the Department of 55 | Computer Science and Engineering at the University of Minnesota. Members of 56 | the GroupLens Research Project are involved in many research projects related 57 | to the fields of information filtering, collaborative filtering, and 58 | recommender systems. The project is led by professors John Riedl and Joseph 59 | Konstan. The project began to explore automated collaborative filtering in 60 | 1992, but is most well known for its world wide trial of an automated 61 | collaborative filtering system for Usenet news in 1996. Since then the project 62 | has expanded its scope to research overall information filtering solutions, 63 | integrating in content-based methods as well as improving current collaborative 64 | filtering technology. 
65 | 66 | Further information on the GroupLens Research project, including research 67 | publications, can be found at the following web site: 68 | 69 | http://www.grouplens.org/ 70 | 71 | GroupLens Research currently operates a movie recommender based on 72 | collaborative filtering: 73 | 74 | http://www.movielens.org/ 75 | 76 | RATINGS FILE DESCRIPTION 77 | ================================================================================ 78 | 79 | All ratings are contained in the file "ratings.dat" and are in the 80 | following format: 81 | 82 | UserID::MovieID::Rating::Timestamp 83 | 84 | - UserIDs range between 1 and 6040 85 | - MovieIDs range between 1 and 3952 86 | - Ratings are made on a 5-star scale (whole-star ratings only) 87 | - Timestamp is represented in seconds since the epoch as returned by time(2) 88 | - Each user has at least 20 ratings 89 | 90 | USERS FILE DESCRIPTION 91 | ================================================================================ 92 | 93 | User information is in the file "users.dat" and is in the following 94 | format: 95 | 96 | UserID::Gender::Age::Occupation::Zip-code 97 | 98 | All demographic information is provided voluntarily by the users and is 99 | not checked for accuracy. Only users who have provided some demographic 100 | information are included in this data set. 
101 | 102 | - Gender is denoted by a "M" for male and "F" for female 103 | - Age is chosen from the following ranges: 104 | 105 | * 1: "Under 18" 106 | * 18: "18-24" 107 | * 25: "25-34" 108 | * 35: "35-44" 109 | * 45: "45-49" 110 | * 50: "50-55" 111 | * 56: "56+" 112 | 113 | - Occupation is chosen from the following choices: 114 | 115 | * 0: "other" or not specified 116 | * 1: "academic/educator" 117 | * 2: "artist" 118 | * 3: "clerical/admin" 119 | * 4: "college/grad student" 120 | * 5: "customer service" 121 | * 6: "doctor/health care" 122 | * 7: "executive/managerial" 123 | * 8: "farmer" 124 | * 9: "homemaker" 125 | * 10: "K-12 student" 126 | * 11: "lawyer" 127 | * 12: "programmer" 128 | * 13: "retired" 129 | * 14: "sales/marketing" 130 | * 15: "scientist" 131 | * 16: "self-employed" 132 | * 17: "technician/engineer" 133 | * 18: "tradesman/craftsman" 134 | * 19: "unemployed" 135 | * 20: "writer" 136 | 137 | MOVIES FILE DESCRIPTION 138 | ================================================================================ 139 | 140 | Movie information is in the file "movies.dat" and is in the following 141 | format: 142 | 143 | MovieID::Title::Genres 144 | 145 | - Titles are identical to titles provided by the IMDB (including 146 | year of release) 147 | - Genres are pipe-separated and are selected from the following genres: 148 | 149 | * Action 150 | * Adventure 151 | * Animation 152 | * Children's 153 | * Comedy 154 | * Crime 155 | * Documentary 156 | * Drama 157 | * Fantasy 158 | * Film-Noir 159 | * Horror 160 | * Musical 161 | * Mystery 162 | * Romance 163 | * Sci-Fi 164 | * Thriller 165 | * War 166 | * Western 167 | 168 | - Some MovieIDs do not correspond to a movie due to accidental duplicate 169 | entries and/or test entries 170 | - Movies are mostly entered by hand, so errors and inconsistencies may exist 171 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/PySparkRecSys-checkpoint.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Import libraries\n", 10 | "\n", 11 | "import tensorflow as tf\n", 12 | "\n", 13 | "from tensorflow import keras\n", 14 | "\n", 15 | "import os\n", 16 | "import numpy as np\n", 17 | "import pandas as pd\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "\n", 20 | "\n", 21 | "from tqdm import tqdm\n", 22 | "\n", 23 | "from pyspark.sql import SparkSession\n", 24 | "spark = SparkSession.builder.appName('Recommendation_system').getOrCreate()\n", 25 | "\n", 26 | "from pyspark.sql.types import StructType, StructField\n", 27 | "from pyspark.sql.types import DoubleType, IntegerType, StringType, DateType" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "+---+----+----+----+---+----+---------+\n", 40 | "|_c0| _c1| _c2| _c3|_c4| _c5| _c6|\n", 41 | "+---+----+----+----+---+----+---------+\n", 42 | "| 1|null|1193|null| 5|null|978300760|\n", 43 | "| 1|null| 661|null| 3|null|978302109|\n", 44 | "| 1|null| 914|null| 3|null|978301968|\n", 45 | "| 1|null|3408|null| 4|null|978300275|\n", 46 | "| 1|null|2355|null| 5|null|978824291|\n", 47 | "| 1|null|1197|null| 3|null|978302268|\n", 48 | "| 1|null|1287|null| 5|null|978302039|\n", 49 | "| 1|null|2804|null| 5|null|978300719|\n", 50 | "| 1|null| 594|null| 4|null|978302268|\n", 51 | "| 1|null| 919|null| 4|null|978301368|\n", 52 | "| 1|null| 595|null| 5|null|978824268|\n", 53 | "| 1|null| 938|null| 4|null|978301752|\n", 54 | "| 1|null|2398|null| 4|null|978302281|\n", 55 | "| 1|null|2918|null| 4|null|978302124|\n", 56 | "| 1|null|1035|null| 5|null|978301753|\n", 57 | "| 1|null|2791|null| 4|null|978302188|\n", 58 | "| 1|null|2687|null| 3|null|978824268|\n", 59 | "| 
1|null|2018|null| 4|null|978301777|\n", 60 | "| 1|null|3105|null| 5|null|978301713|\n", 61 | "| 1|null|2797|null| 4|null|978302039|\n", 62 | "| 1|null|2321|null| 3|null|978302205|\n", 63 | "| 1|null| 720|null| 3|null|978300760|\n", 64 | "| 1|null|1270|null| 5|null|978300055|\n", 65 | "| 1|null| 527|null| 5|null|978824195|\n", 66 | "| 1|null|2340|null| 3|null|978300103|\n", 67 | "| 1|null| 48|null| 5|null|978824351|\n", 68 | "| 1|null|1097|null| 4|null|978301953|\n", 69 | "| 1|null|1721|null| 4|null|978300055|\n", 70 | "| 1|null|1545|null| 4|null|978824139|\n", 71 | "| 1|null| 745|null| 3|null|978824268|\n", 72 | "| 1|null|2294|null| 4|null|978824291|\n", 73 | "| 1|null|3186|null| 4|null|978300019|\n", 74 | "| 1|null|1566|null| 4|null|978824330|\n", 75 | "| 1|null| 588|null| 4|null|978824268|\n", 76 | "| 1|null|1907|null| 4|null|978824330|\n", 77 | "| 1|null| 783|null| 4|null|978824291|\n", 78 | "| 1|null|1836|null| 5|null|978300172|\n", 79 | "| 1|null|1022|null| 5|null|978300055|\n", 80 | "| 1|null|2762|null| 4|null|978302091|\n", 81 | "| 1|null| 150|null| 5|null|978301777|\n", 82 | "| 1|null| 1|null| 5|null|978824268|\n", 83 | "| 1|null|1961|null| 5|null|978301590|\n", 84 | "| 1|null|1962|null| 4|null|978301753|\n", 85 | "| 1|null|2692|null| 4|null|978301570|\n", 86 | "| 1|null| 260|null| 4|null|978300760|\n", 87 | "| 1|null|1028|null| 5|null|978301777|\n", 88 | "| 1|null|1029|null| 5|null|978302205|\n", 89 | "| 1|null|1207|null| 4|null|978300719|\n", 90 | "| 1|null|2028|null| 5|null|978301619|\n", 91 | "| 1|null| 531|null| 4|null|978302149|\n", 92 | "| 1|null|3114|null| 4|null|978302174|\n", 93 | "| 1|null| 608|null| 4|null|978301398|\n", 94 | "| 1|null|1246|null| 4|null|978302091|\n", 95 | "| 2|null|1357|null| 5|null|978298709|\n", 96 | "| 2|null|3068|null| 4|null|978299000|\n", 97 | "| 2|null|1537|null| 4|null|978299620|\n", 98 | "| 2|null| 647|null| 3|null|978299351|\n", 99 | "| 2|null|2194|null| 4|null|978299297|\n", 100 | "| 2|null| 648|null| 
4|null|978299913|\n", 101 | "| 2|null|2268|null| 5|null|978299297|\n", 102 | "| 2|null|2628|null| 3|null|978300051|\n", 103 | "| 2|null|1103|null| 3|null|978298905|\n", 104 | "| 2|null|2916|null| 3|null|978299809|\n", 105 | "| 2|null|3468|null| 5|null|978298542|\n", 106 | "| 2|null|1210|null| 4|null|978298151|\n", 107 | "| 2|null|1792|null| 3|null|978299941|\n", 108 | "| 2|null|1687|null| 3|null|978300174|\n", 109 | "| 2|null|1213|null| 2|null|978298458|\n", 110 | "| 2|null|3578|null| 5|null|978298958|\n", 111 | "| 2|null|2881|null| 3|null|978300002|\n", 112 | "| 2|null|3030|null| 4|null|978298434|\n", 113 | "| 2|null|1217|null| 3|null|978298151|\n", 114 | "| 2|null|3105|null| 4|null|978298673|\n", 115 | "| 2|null| 434|null| 2|null|978300174|\n", 116 | "| 2|null|2126|null| 3|null|978300123|\n", 117 | "| 2|null|3107|null| 2|null|978300002|\n", 118 | "| 2|null|3108|null| 3|null|978299712|\n", 119 | "| 2|null|3035|null| 4|null|978298625|\n", 120 | "| 2|null|1253|null| 3|null|978299120|\n", 121 | "| 2|null|1610|null| 5|null|978299809|\n", 122 | "| 2|null| 292|null| 3|null|978300123|\n", 123 | "| 2|null|2236|null| 5|null|978299220|\n", 124 | "| 2|null|3071|null| 4|null|978299120|\n", 125 | "| 2|null| 902|null| 2|null|978298905|\n", 126 | "| 2|null| 368|null| 4|null|978300002|\n", 127 | "| 2|null|1259|null| 5|null|978298841|\n", 128 | "| 2|null|3147|null| 5|null|978298652|\n", 129 | "| 2|null|1544|null| 4|null|978300174|\n", 130 | "| 2|null|1293|null| 5|null|978298261|\n", 131 | "| 2|null|1188|null| 4|null|978299620|\n", 132 | "| 2|null|3255|null| 4|null|978299321|\n", 133 | "| 2|null|3256|null| 2|null|978299839|\n", 134 | "| 2|null|3257|null| 3|null|978300073|\n", 135 | "| 2|null| 110|null| 5|null|978298625|\n", 136 | "| 2|null|2278|null| 3|null|978299889|\n", 137 | "| 2|null|2490|null| 3|null|978299966|\n", 138 | "| 2|null|1834|null| 4|null|978298813|\n", 139 | "| 2|null|3471|null| 5|null|978298814|\n", 140 | "| 2|null| 589|null| 4|null|978299773|\n", 141 | "| 
2|null|1690|null| 3|null|978300051|\n", 142 | "+---+----+----+----+---+----+---------+\n", 143 | "only showing top 100 rows\n", 144 | "\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "df = spark.read.csv(\"data/mldataset/ratings.dat\", sep=':')\n", 150 | "df.show(100,truncate=True)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 3, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "+------+-----+-------+-----+------+-----+---------+\n", 163 | "|UserId|Null1|MovieId|Null2|Rating|Null3|TimeStamp|\n", 164 | "+------+-----+-------+-----+------+-----+---------+\n", 165 | "| 1| null| 1193| null| 5| null|978300760|\n", 166 | "| 1| null| 661| null| 3| null|978302109|\n", 167 | "| 1| null| 914| null| 3| null|978301968|\n", 168 | "| 1| null| 3408| null| 4| null|978300275|\n", 169 | "| 1| null| 2355| null| 5| null|978824291|\n", 170 | "| 1| null| 1197| null| 3| null|978302268|\n", 171 | "| 1| null| 1287| null| 5| null|978302039|\n", 172 | "| 1| null| 2804| null| 5| null|978300719|\n", 173 | "| 1| null| 594| null| 4| null|978302268|\n", 174 | "| 1| null| 919| null| 4| null|978301368|\n", 175 | "| 1| null| 595| null| 5| null|978824268|\n", 176 | "| 1| null| 938| null| 4| null|978301752|\n", 177 | "| 1| null| 2398| null| 4| null|978302281|\n", 178 | "| 1| null| 2918| null| 4| null|978302124|\n", 179 | "| 1| null| 1035| null| 5| null|978301753|\n", 180 | "| 1| null| 2791| null| 4| null|978302188|\n", 181 | "| 1| null| 2687| null| 3| null|978824268|\n", 182 | "| 1| null| 2018| null| 4| null|978301777|\n", 183 | "| 1| null| 3105| null| 5| null|978301713|\n", 184 | "| 1| null| 2797| null| 4| null|978302039|\n", 185 | "+------+-----+-------+-----+------+-----+---------+\n", 186 | "only showing top 20 rows\n", 187 | "\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "# Let's add headers and drop empty columns\n", 193 | "\n", 194 | "headerd_schema = StructType([\n", 195 | " 
StructField(\"UserId\", IntegerType()),\n", 196 | " StructField(\"Null1\", StringType()),\n", 197 | " StructField(\"MovieId\", IntegerType()),\n", 198 | " StructField(\"Null2\", StringType()),\n", 199 | " StructField(\"Rating\", IntegerType()),\n", 200 | " StructField(\"Null3\", StringType()),\n", 201 | " StructField(\"TimeStamp\", IntegerType())\n", 202 | "])\n", 203 | "\n", 204 | "df = spark.read.schema(headerd_schema).csv(\"data/mldataset/ratings.dat\", sep=\":\",header=False)\n", 205 | "df.show()" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 4, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "df = df.drop('Null1', 'Null2', 'Null3')" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 5, 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "name": "stdout", 224 | "output_type": "stream", 225 | "text": [ 226 | "+------+-------+------+---------+\n", 227 | "|UserId|MovieId|Rating|TimeStamp|\n", 228 | "+------+-------+------+---------+\n", 229 | "| 6040| 858| 4|956703932|\n", 230 | "| 6040| 593| 5|956703954|\n", 231 | "| 6040| 2384| 4|956703954|\n", 232 | "| 6040| 1961| 4|956703977|\n", 233 | "| 6040| 2019| 5|956703977|\n", 234 | "| 6040| 3111| 5|956704056|\n", 235 | "| 6040| 573| 4|956704056|\n", 236 | "| 6040| 3505| 4|956704056|\n", 237 | "| 6040| 213| 5|956704056|\n", 238 | "| 6040| 1419| 3|956704056|\n", 239 | "| 6040| 1734| 2|956704081|\n", 240 | "| 6040| 2503| 5|956704191|\n", 241 | "| 6040| 919| 5|956704191|\n", 242 | "| 6040| 912| 5|956704191|\n", 243 | "| 6040| 527| 5|956704219|\n", 244 | "| 6040| 649| 5|956704257|\n", 245 | "| 6040| 318| 4|956704257|\n", 246 | "| 6040| 1252| 5|956704257|\n", 247 | "| 6040| 3289| 5|956704305|\n", 248 | "| 6040| 759| 5|956704448|\n", 249 | "+------+-------+------+---------+\n", 250 | "only showing top 20 rows\n", 251 | "\n" 252 | ] 253 | } 254 | ], 255 | "source": [ 256 | "df.orderBy('TimeStamp').show()" 257 | ] 258 | }, 259 | { 260 | "cell_type": 
"code", 261 | "execution_count": 7, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "df = pd.read_csv(\"data/mldataset/ratings.dat\", sep=\"::\",header=None, engine='python')\n", 266 | "df.columns = ['UserId','MovieId','Rating','TimeStamp']" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 9, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "n_users = len(df['UserId'].unique())\n", 276 | "n_movies = len(df['MovieId'].unique())\n", 277 | "\n", 278 | "n_features = 50" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "model = keras.models.Sequential()\n", 288 | "model.add(keras.layers.Embedding(1000, 64, input_length=10))\n", 289 | "model.add(keras.layers.Dropout(0.05))\n", 290 | "model.add(keras.layers.Dense(units = 150, activation = 'relu'))\n", 291 | "model.add(keras.layers.Dropout(0.5))\n", 292 | "model.add(keras.layers.Dense(units = 50, activation = 'softmax'))\n", 293 | "\n", 294 | "\n", 295 | "model.compile(optimizer='adam',\n", 296 | " loss='mse',\n", 297 | " metrics=['accuracy'])\n" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [] 306 | } 307 | ], 308 | "metadata": { 309 | "kernelspec": { 310 | "display_name": "Python 3", 311 | "language": "python", 312 | "name": "python3" 313 | }, 314 | "language_info": { 315 | "codemirror_mode": { 316 | "name": "ipython", 317 | "version": 3 318 | }, 319 | "file_extension": ".py", 320 | "mimetype": "text/x-python", 321 | "name": "python", 322 | "nbconvert_exporter": "python", 323 | "pygments_lexer": "ipython3", 324 | "version": "3.7.6" 325 | } 326 | }, 327 | "nbformat": 4, 328 | "nbformat_minor": 4 329 | } 330 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/RecSysKeras-checkpoint.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import numpy as np\n", 11 | "from sklearn.model_selection import train_test_split\n", 12 | "import scipy.sparse as sp\n", 13 | "\n", 14 | "from keras.layers import Input, Lambda, merge, Dense, Flatten, Embedding, concatenate\n", 15 | "from keras.models import Model, Sequential\n", 16 | "from keras.regularizers import l2\n", 17 | "from keras import backend as K\n", 18 | "from keras.optimizers import SGD,Adam\n", 19 | "from keras.losses import binary_crossentropy\n", 20 | "import numpy.random as rng\n", 21 | "import numpy as np\n", 22 | "import os\n", 23 | "import pickle\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "import seaborn as sns\n", 26 | "from sklearn.utils import shuffle\n", 27 | "\n", 28 | "import preprocessing_data as data\n", 29 | "import metrics" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## Load and transform data\n", 37 | "We're going to load the Movielens $1M$ dataset and create triplets of (user, known positive item, randomly sampled negative item).\n", 38 | "\n", 39 | "The success metric is AUC: in this case, the probability that a randomly chosen known positive item from the test set is ranked higher for a given user than a randomly chosen negative item." 
40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 4, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "df = pd.read_csv('data/mldataset/ratings.dat', sep = '::', engine='python', header=None)\n", 49 | "df.columns = ['UserId', 'MovieId', 'Rating', 'Timestamp']\n", 50 | "\n", 51 | "x = np.array(df[['UserId', 'MovieId']])\n", 52 | "y = np.array(df['Rating'])\n", 53 | "\n", 54 | "# Read data\n", 55 | "train, test = data.get_movielens_data()\n", 56 | "num_users, num_items = max(df.UserId) +1, max(df.MovieId) +1\n", 57 | "\n", 58 | "# Prepare the test triplets\n", 59 | "test_uid, test_pid, test_nid = data.get_triplets(test)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "Define a metric between pairs, the _triplet loss function_." 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## Neural Network Architecture" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 3, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def identity_loss(y_true, y_pred): \n", 83 | "\n", 84 | " return K.mean(y_pred - 0 * y_true)\n", 85 | "\n", 86 | "def triplet_loss(inputs, alpha = 0.05):\n", 87 | "\n", 88 | " anchor, positive, negative = inputs\n", 89 | " \n", 90 | " pos_dist = K.sum(K.square(anchor-positive), axis=-1)\n", 91 | " neg_dist = K.sum(K.square(anchor-negative), axis=-1)\n", 92 | " loss = K.sum(K.maximum(pos_dist - neg_dist + alpha, 0), axis=0)\n", 93 | "\n", 94 | " return loss\n", 95 | "\n", 96 | "def bpr_triplet_loss(inputs):\n", 97 | "\n", 98 | " anchor_latent, positive_item_latent, negative_item_latent = inputs\n", 99 | "\n", 100 | " # BPR loss\n", 101 | " loss = 1.0 - K.sigmoid(\n", 102 | " K.sum(anchor_latent * positive_item_latent, axis=-1, keepdims=True) -\n", 103 | " K.sum(anchor_latent * negative_item_latent, axis=-1, keepdims=True))\n", 104 | "\n", 105 | " return loss\n", 106 | "\n", 107 | "def triploss(x): 
\n", 108 | " res = tf.py_function(bpr_triplet_loss, [x], tf.float32)\n", 109 | " res.set_shape((None, 1))\n", 110 | " return res " 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 4, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "def getModel(n_users, n_items, emb_dim = 20, margin=1):\n", 120 | " \n", 121 | " # Input Layers\n", 122 | " user_input = Input(shape=[1], name = 'user_input')\n", 123 | " pos_item_input = Input(shape=[1], name = 'pos_item_input')\n", 124 | " neg_item_input = Input(shape=[1], name = 'neg_item_input')\n", 125 | " \n", 126 | " # Embedding Layers\n", 127 | " # Shared embedding layer for positive and negative items\n", 128 | " user_embedding = Embedding(output_dim=emb_dim, input_dim=n_users + 1, input_length=1, name='user_emb')(user_input)\n", 129 | " item_embedding = Embedding(output_dim=emb_dim, input_dim=n_items + 1, input_length=1, name='item_emb')\n", 130 | " \n", 131 | " pos_item_embedding = item_embedding(pos_item_input)\n", 132 | " neg_item_embedding = item_embedding(neg_item_input)\n", 133 | " \n", 134 | " user_vecs = Flatten(name='user_emb_vec')(user_embedding)\n", 135 | " pos_item_vecs = Flatten(name='pos_emb_vec')(pos_item_embedding)\n", 136 | " neg_item_vecs = Flatten(name='neg_emb_vec')(neg_item_embedding)\n", 137 | " \n", 138 | " # Triplet loss function \n", 139 | " AP_loss = Lambda(lambda tensors:K.sum(K.square(tensors[0]*tensors[1]),axis=-1,keepdims=True),name='AP_loss')([user_vecs, pos_item_vecs])\n", 140 | " AN_loss = Lambda(lambda tensors:K.sum(K.square(tensors[0]*tensors[1]),axis=-1,keepdims=True),name='AN_loss')([user_vecs, neg_item_vecs])\n", 141 | " Triplet_loss = Lambda(lambda loss: 1.0 - K.sigmoid(loss[0] - loss[1]),\n", 142 | " name='Triplet_loss')\n", 143 | " \n", 144 | " #call this layer on list of two input tensors.\n", 145 | " Final_loss = Triplet_loss([AP_loss, AN_loss])\n", 146 | "\n", 147 | " model = Model(inputs=[user_input, pos_item_input, 
neg_item_input],outputs=[Final_loss])\n", 148 | " model.compile(loss=identity_loss, optimizer=Adam(), metrics=['accuracy'])\n", 149 | " \n", 150 | " return model" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 5, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "Model: \"model_1\"\n", 163 | "__________________________________________________________________________________________________\n", 164 | "Layer (type) Output Shape Param # Connected to \n", 165 | "==================================================================================================\n", 166 | "user_input (InputLayer) (None, 1) 0 \n", 167 | "__________________________________________________________________________________________________\n", 168 | "pos_item_input (InputLayer) (None, 1) 0 \n", 169 | "__________________________________________________________________________________________________\n", 170 | "neg_item_input (InputLayer) (None, 1) 0 \n", 171 | "__________________________________________________________________________________________________\n", 172 | "user_emb (Embedding) (None, 1, 100) 604200 user_input[0][0] \n", 173 | "__________________________________________________________________________________________________\n", 174 | "item_emb (Embedding) (None, 1, 100) 395400 pos_item_input[0][0] \n", 175 | " neg_item_input[0][0] \n", 176 | "__________________________________________________________________________________________________\n", 177 | "user_emb_vec (Flatten) (None, 100) 0 user_emb[0][0] \n", 178 | "__________________________________________________________________________________________________\n", 179 | "pos_emb_vec (Flatten) (None, 100) 0 item_emb[0][0] \n", 180 | "__________________________________________________________________________________________________\n", 181 | "neg_emb_vec (Flatten) (None, 100) 0 item_emb[1][0] \n", 182 | 
"__________________________________________________________________________________________________\n", 183 | "AP_loss (Lambda) (None, 1) 0 user_emb_vec[0][0] \n", 184 | " pos_emb_vec[0][0] \n", 185 | "__________________________________________________________________________________________________\n", 186 | "AN_loss (Lambda) (None, 1) 0 user_emb_vec[0][0] \n", 187 | " neg_emb_vec[0][0] \n", 188 | "__________________________________________________________________________________________________\n", 189 | "Triplet_loss (Lambda) (None, 1) 0 AP_loss[0][0] \n", 190 | " AN_loss[0][0] \n", 191 | "==================================================================================================\n", 192 | "Total params: 999,600\n", 193 | "Trainable params: 999,600\n", 194 | "Non-trainable params: 0\n", 195 | "__________________________________________________________________________________________________\n", 196 | "None\n", 197 | "AUC before training 0.4979528166910992\n" 198 | ] 199 | } 200 | ], 201 | "source": [ 202 | "emb_dim = 100\n", 203 | "n_epochs = 20\n", 204 | "\n", 205 | "model = getModel(num_users, num_items, emb_dim)\n", 206 | "\n", 207 | "# Print the model structure\n", 208 | "print(model.summary())\n", 209 | "\n", 210 | "# Sanity check, should be around 0.5\n", 211 | "print('AUC before training %s' % metrics.full_auc(model, test))\n" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 6, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "Epoch 0\n" 224 | ] 225 | }, 226 | { 227 | "name": "stderr", 228 | "output_type": "stream", 229 | "text": [ 230 | "/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/framework/indexed_slices.py:424: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n", 231 | " \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. 
\"\n" 232 | ] 233 | }, 234 | { 235 | "name": "stdout", 236 | "output_type": "stream", 237 | "text": [ 238 | "AUC 0.4933859503570521\n", 239 | "Epoch 1\n", 240 | "AUC 0.4923985335323258\n", 241 | "Epoch 2\n", 242 | "AUC 0.4926312621200061\n", 243 | "Epoch 3\n", 244 | "AUC 0.4953075512386629\n", 245 | "Epoch 4\n", 246 | "AUC 0.49496981619833613\n", 247 | "Epoch 5\n", 248 | "AUC 0.4951794481094716\n", 249 | "Epoch 6\n", 250 | "AUC 0.495571400337752\n", 251 | "Epoch 7\n", 252 | "AUC 0.4966562351203805\n", 253 | "Epoch 8\n", 254 | "AUC 0.4957264756223382\n", 255 | "Epoch 9\n", 256 | "AUC 0.4967664864742963\n", 257 | "Epoch 10\n", 258 | "AUC 0.4988876970609935\n", 259 | "Epoch 11\n", 260 | "AUC 0.4986669394590502\n", 261 | "Epoch 12\n", 262 | "AUC 0.4983488105028148\n", 263 | "Epoch 13\n", 264 | "AUC 0.4979363490233906\n", 265 | "Epoch 14\n", 266 | "AUC 0.4980200197223901\n", 267 | "Epoch 15\n", 268 | "AUC 0.49754852043498443\n", 269 | "Epoch 16\n", 270 | "AUC 0.49775049834168994\n", 271 | "Epoch 17\n", 272 | "AUC 0.49821718098914103\n", 273 | "Epoch 18\n", 274 | "AUC 0.4987022261658519\n", 275 | "Epoch 19\n", 276 | "AUC 0.49949989949587953\n" 277 | ] 278 | } 279 | ], 280 | "source": [ 281 | "for epoch in range(n_epochs):\n", 282 | "\n", 283 | " print('Epoch %s' % epoch)\n", 284 | "\n", 285 | " # Sample triplets from the training data\n", 286 | " uid, pid, nid = data.get_triplets(train)\n", 287 | "\n", 288 | " X = {\n", 289 | " 'user_input': uid,\n", 290 | " 'pos_item_input': pid,\n", 291 | " 'neg_item_input': nid\n", 292 | " }\n", 293 | "\n", 294 | " model.fit(X,\n", 295 | " np.ones(len(uid)),\n", 296 | " batch_size=64,\n", 297 | " epochs=1,\n", 298 | " verbose=0,\n", 299 | " shuffle=True)\n", 300 | "\n", 301 | " print('AUC %s' % metrics.full_auc(model, test))" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 7, 307 | "metadata": {}, 308 | "outputs": [ 309 | { 310 | "name": "stdout", 311 | "output_type": "stream", 312 | "text": [ 313 | "Epoch 
1/20\n", 314 | "49906/49906 [==============================] - 7s 137us/step - loss: 0.1102 - accuracy: 0.1033\n", 315 | "Epoch 2/20\n", 316 | "49906/49906 [==============================] - 7s 133us/step - loss: 0.1020 - accuracy: 0.0945\n", 317 | "Epoch 3/20\n", 318 | "49906/49906 [==============================] - 6s 129us/step - loss: 0.0949 - accuracy: 0.0867\n", 319 | "Epoch 4/20\n", 320 | "49906/49906 [==============================] - 7s 134us/step - loss: 0.0887 - accuracy: 0.0807\n", 321 | "Epoch 5/20\n", 322 | "49906/49906 [==============================] - 6s 124us/step - loss: 0.0830 - accuracy: 0.0758\n", 323 | "Epoch 6/20\n", 324 | "49906/49906 [==============================] - 6s 117us/step - loss: 0.0781 - accuracy: 0.0720\n", 325 | "Epoch 7/20\n", 326 | "49906/49906 [==============================] - 7s 137us/step - loss: 0.0738 - accuracy: 0.0686\n", 327 | "Epoch 8/20\n", 328 | "49906/49906 [==============================] - 6s 123us/step - loss: 0.0699 - accuracy: 0.0653\n", 329 | "Epoch 9/20\n", 330 | "49906/49906 [==============================] - 7s 138us/step - loss: 0.0665 - accuracy: 0.0628\n", 331 | "Epoch 10/20\n", 332 | "49906/49906 [==============================] - 7s 133us/step - loss: 0.0636 - accuracy: 0.0605\n", 333 | "Epoch 11/20\n", 334 | "49906/49906 [==============================] - 6s 113us/step - loss: 0.0611 - accuracy: 0.0586\n", 335 | "Epoch 12/20\n", 336 | "49906/49906 [==============================] - 7s 134us/step - loss: 0.0587 - accuracy: 0.0565\n", 337 | "Epoch 13/20\n", 338 | "49906/49906 [==============================] - 6s 128us/step - loss: 0.0563 - accuracy: 0.0543\n", 339 | "Epoch 14/20\n", 340 | "49906/49906 [==============================] - 6s 126us/step - loss: 0.0542 - accuracy: 0.0527\n", 341 | "Epoch 15/20\n", 342 | "49906/49906 [==============================] - 7s 133us/step - loss: 0.0526 - accuracy: 0.0514\n", 343 | "Epoch 16/20\n", 344 | "49906/49906 [==============================] - 6s 
126us/step - loss: 0.0511 - accuracy: 0.0501\n", 345 | "Epoch 17/20\n", 346 | "49906/49906 [==============================] - 6s 117us/step - loss: 0.0496 - accuracy: 0.0486\n", 347 | "Epoch 18/20\n", 348 | "49906/49906 [==============================] - 6s 122us/step - loss: 0.0483 - accuracy: 0.0474\n", 349 | "Epoch 19/20\n", 350 | "49906/49906 [==============================] - 7s 143us/step - loss: 0.0472 - accuracy: 0.0465\n", 351 | "Epoch 20/20\n", 352 | "49906/49906 [==============================] - 8s 151us/step - loss: 0.0461 - accuracy: 0.0454\n" 353 | ] 354 | }, 355 | { 356 | "data": { 357 | "text/plain": [ 358 | "" 359 | ] 360 | }, 361 | "execution_count": 7, 362 | "metadata": {}, 363 | "output_type": "execute_result" 364 | } 365 | ], 366 | "source": [ 367 | "model.fit(X,\n", 368 | " np.ones(len(uid)),\n", 369 | " batch_size=64,\n", 370 | " epochs=20,\n", 371 | " verbose=1,\n", 372 | " shuffle=True)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 9, 378 | "metadata": {}, 379 | "outputs": [ 380 | { 381 | "data": { 382 | "text/plain": [ 383 | "UserId 4\n", 384 | "MovieId 2951\n", 385 | "Rating 4\n", 386 | "Timestamp 978294282\n", 387 | "Name: 235, dtype: int64" 388 | ] 389 | }, 390 | "execution_count": 9, 391 | "metadata": {}, 392 | "output_type": "execute_result" 393 | } 394 | ], 395 | "source": [ 396 | "df.iloc[235]" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 21, 402 | "metadata": {}, 403 | "outputs": [ 404 | { 405 | "data": { 406 | "text/plain": [ 407 | "-8.945791" 408 | ] 409 | }, 410 | "execution_count": 21, 411 | "metadata": {}, 412 | "output_type": "execute_result" 413 | } 414 | ], 415 | "source": [ 416 | "metrics.predict(model, uid = 1, pids = 5)" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 17, 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "data": { 426 | "text/plain": [ 427 | "array([ 1, 3, 6, ..., 928, 943, 1074], dtype=int32)" 428 | ] 429 
| }, 430 | "execution_count": 17, 431 | "metadata": {}, 432 | "output_type": "execute_result" 433 | } 434 | ], 435 | "source": [ 436 | "pid" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [] 445 | } 446 | ], 447 | "metadata": { 448 | "kernelspec": { 449 | "display_name": "Python 3", 450 | "language": "python", 451 | "name": "python3" 452 | }, 453 | "language_info": { 454 | "codemirror_mode": { 455 | "name": "ipython", 456 | "version": 3 457 | }, 458 | "file_extension": ".py", 459 | "mimetype": "text/x-python", 460 | "name": "python", 461 | "nbconvert_exporter": "python", 462 | "pygments_lexer": "ipython3", 463 | "version": "3.7.6" 464 | } 465 | }, 466 | "nbformat": 4, 467 | "nbformat_minor": 4 468 | } 469 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/RecSysTripletLoss-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Using TensorFlow backend.\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import os\n", 18 | "import numpy as np\n", 19 | "np.random.seed(0)\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "%matplotlib inline\n", 22 | "from pylab import *\n", 23 | "from keras.models import Sequential\n", 24 | "from keras.optimizers import Adam\n", 25 | "from keras.layers import Conv2D, ZeroPadding2D, Activation, Input, concatenate\n", 26 | "from keras.models import Model\n", 27 | "from keras.datasets import mnist\n", 28 | "\n", 29 | "from keras.layers.normalization import BatchNormalization\n", 30 | "from keras.layers.pooling import MaxPooling2D\n", 31 | "from keras.layers.merge import Concatenate\n", 32 | "from keras.layers.core import Lambda, Flatten, Dense\n", 33 | "from 
keras.initializers import glorot_uniform,he_uniform\n", 34 | "\n", 35 | "from keras.engine.topology import Layer\n", 36 | "from keras.regularizers import l2\n", 37 | "from keras import backend as K\n", 38 | "from keras.utils import plot_model,normalize\n", 39 | "\n", 40 | "from sklearn.metrics import roc_curve,roc_auc_score" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 2, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "def DrawPics(tensor,n=0,template='{}',classnumber=None):\n", 50 | " if (n==0):\n", 51 | " N = tensor.shape[0]\n", 52 | " else:\n", 53 | " N = min(n,tensor.shape[0])\n", 54 | " fig=plt.figure(figsize=(16,2))\n", 55 | " nbligne = floor(N/20)+1\n", 56 | " for m in range(N):\n", 57 | " subplot = fig.add_subplot(nbligne,min(N,20),m+1)\n", 58 | " axis(\"off\")\n", 59 | " plt.imshow(tensor[m,:,:,0],vmin=0, vmax=1,cmap='Greys')\n", 60 | " if (classnumber!=None):\n", 61 | " subplot.title.set_text((template.format(classnumber)))" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "nb_classes = 10\n", 71 | "img_rows, img_cols = 28, 28\n", 72 | "input_shape = (img_rows, img_cols, 1)\n", 73 | "\n", 74 | "def buildDataSet():\n", 75 | " \"\"\"Build dataset for train and test\n", 76 | " \n", 77 | " \n", 78 | " returns:\n", 79 | " dataset : list of lengh 10 containing images for each classes of shape (?,28,28,1)\n", 80 | " \"\"\"\n", 81 | " (x_train_origin, y_train_origin), (x_test_origin, y_test_origin) = mnist.load_data()\n", 82 | "\n", 83 | " assert K.image_data_format() == 'channels_last'\n", 84 | " x_train_origin = x_train_origin.reshape(x_train_origin.shape[0], img_rows, img_cols, 1)\n", 85 | " x_test_origin = x_test_origin.reshape(x_test_origin.shape[0], img_rows, img_cols, 1)\n", 86 | " \n", 87 | " dataset_train = []\n", 88 | " dataset_test = []\n", 89 | " \n", 90 | " #Sorting images by classes and normalize values 0=>1\n", 91 | " 
for n in range(nb_classes):\n", 92 | " images_class_n = np.asarray([row for idx,row in enumerate(x_train_origin) if y_train_origin[idx]==n])\n", 93 | " dataset_train.append(images_class_n/255)\n", 94 | " \n", 95 | " images_class_n = np.asarray([row for idx,row in enumerate(x_test_origin) if y_test_origin[idx]==n])\n", 96 | " dataset_test.append(images_class_n/255)\n", 97 | " \n", 98 | " return dataset_train,dataset_test,x_train_origin,y_train_origin,x_test_origin,y_test_origin" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 4, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "name": "stdout", 108 | "output_type": "stream", 109 | "text": [ 110 | "Checking shapes for class 0 (train) : (5923, 28, 28, 1)\n", 111 | "Checking shapes for class 0 (test) : (980, 28, 28, 1)\n", 112 | "Checking first samples\n" 113 | ] 114 | }, 115 | { 116 | "data": { 117 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA10AAACLCAYAAACa9PPwAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAUgUlEQVR4nO3deYxU1bbH8b1lEJopNAIBAZugiBqe0I/EGBkUMZeohBnhgYAMPsFWEAdE0i+KoBBFaCQoNpHxPmQSCJiHAoqiggGEDioBjTKLjM0gM5z3B9DptZCqOl1nd1Wd/n6Sm9SPqnNq92VRXcs6q7b1PM8AAAAAANy4KdELAAAAAIAwo+kCAAAAAIdougAAAADAIZouAAAAAHCIpgsAAAAAHKLpAgAAAACHSlzTZa0tZa09Za2tl+i1AC5Q4wg7ahxhR40jzEpqfSd903X1L+Xa/y5ba88Uyj39ns/zvEue51X0PG93EdeTaa390Vp72lq7wVr7H0U5D3ANNY6wo8YRdtQ4woz6DkbSN11X/1Iqep5X0Riz2xjTrtCf/Vs/3lpb2tVarLU3G2OWGmOmG2OqGmPmGmOWWGvLuHpOhB81jrCjxhF21DjCjPoORtI3XdFYa0dba+dZa+daa08aY3pZa++31q631uZba/+01k669pdhrS1trfWstRlX85yr9/+ftfaktXadtbb+DZ7uYWOM53ne+57nnTPGTDDG3GyMaeX+J0VJRY0j7KhxhB01jjCjvmOT8k3XVR2NMf9rjKlijJlnjLlojBlijLnFGPOAMaatMea/Ixz/X8aYbGNMurnSwb95g8fdY4zJuxY8z/OMMVuv/jngEjWOsKPGEXbUOMKM+o4iLE3Xt57nLfM877LneWc8z9vged4Pnudd9Dzvd2PMRyZyB7zQ87yNnuddMMb82xjT5AaPq2iMOa7+7LgxplLcPwEQGTWOsKPGEXbUOMKM+o7C2TWXxWxP4WCtbWSMGW+M+U9jTJq58nP+EOH4A4VunzZX/kL/ySljTGX1Z5WNMSf9LBYoAmocYUeNI+yocYQZ9R1FWD7p8lSeaoz5
yRhzu+d5lY0x/2OMsQE8z8/GmHuvBWutNcY0vvrngEvUOMKOGkfYUeMIM+o7irA0XVolc+Wjxr+ttXeZyNeQ+vGlMaaUtfbZq9+eMsQYc8EY83VA5wdiRY0j7KhxhB01jjCjvpWwNl0vGmP6mCsfNU41Vwb64uZ53lljTHtjzABjTL4xppcxpv3V60+B4kSNI+yocYQdNY4wo74Ve+VLPwAAAAAALoT1ky4AAAAASAo0XQAAAADgEE0XAAAAADhE0wUAAAAADtF0AQAAAIBDpaPcz1cbpp4gNp4rSajx1EON+0ONpx5qPHbUd+qhvv2hxlPPP9Y4n3QBAAAAgEM0XQAAAADgEE0XAAAAADhE0wUAAAAADtF0AQAAAIBDNF0AAAAA4BBNFwAAAAA4RNMFAAAAAA7RdAEAAACAQzRdAAAAAOAQTRcAAAAAOETTBQAAAAAO0XQBAAAAgEM0XQAAAADgEE0XAAAAADhUOtELSAZ79uwROScnR+QJEyaI/MILL4g8ZMgQkevWrRvg6gD3tm/fLvLdd98t8uXLlyM+vmHDhm4WBgAweXl5ImdmZopcr149kZcsWSLyHXfcIXJaWlqAqwMQCz7pAgAAAACHaLoAAAAAwCHreV6k+yPemar27dsn8r333ityfn6+r/NVrVpV5EOHDhVtYcGwiXzyFBTKGo9m9erVIo8aNUrk77//XmR9eWHz5s1FzsrKErlz584i33RToP99hxr3J5Q1vmvXLpEzMjJE9ltzH374ocgDBw4s0roCQo3HLpT1renLC5s1a+br+D59+og8bdq0uNcUB+rbn9DU+P79+wtub9y4UdzXoUOHuM6t+5n09HSRt23bJnKNGjXier4o/rHG+aQLAAAAAByi6QIAAAAAh2i6AAAAAMChEvOV8YWv/3/wwQfFfceOHRPZWnkpZpUqVUS++eabRT548KDIv//+u8i33XabyKVKlYq+YCBAeoZLz6/oGa5o9ON1Pnz4sMj63xAQr+zsbJH1DJffma7BgweLfODAAZG7d+8usv4KbiBeZ8+eLbj98ssvi/uWLVsW17nvv//+uI4HYnHkyBGRP/vsM5HHjBlTcPvXX38V9+n33n7p4/X3M7Rr107k+fPni6zfq7vAJ10AAAAA4BBNFwAAAAA4RNMFAAAAAA6FZp+uCxcuiKz3cGnbtm3B7Z07d4r79P8H+rrQVq1aiVz4mlRjrt+zSJ/vo48+Erl///7GIfa/8CdlalwrfP2/MbLm9bXLf/75Z8RjtczMTJEvXbokst4zRnM800WN+5OyNa6vyW/ZsmXB7T179oj7Tpw4IXK8e8PpvekWLVokcrx7ykRBjccuZetbKzzj0r59e3Hf9u3bRY63vtetWyey332/4kR9+5O0NX7u3DmR9XsPPU9eWLT33q7NmjVL5J49ewZ5evbpAgAAAIDiRtMFAAAAAA7RdAEAAACAQ6HZp0vvaTF58uTAzv3111+L/Pfff4vcsWNHkT/99FORN2/eHNhaUHJt2LBB5EmTJon8ySefFNzW8yh+r/8fO3asyPp8hWckAVf0LOG2bdsStBLAvePHjxfc1jOKQdMzY3ofMD3Xi5JJz3+PHDlS5Dlz5ois57mDVK5cOZFvueUWkffu3evsuYPCJ10AAAAA4BBNFwAAAAA4RNMFAAAAAA6l7EyX3qNFX1caaf8xPYPVuXNnkXv16iVy3bp1Rb7rrrtEHj58uMgLFy6MeS3AjehZwtatW8d8rJ7B8itazcZ7fiAWr7/+emDnWrFihchr164VWe+/CATtzTffFNlPfQf9mrt//36R9f6lzHTBmOvfh0ycODFBK7n+vbd+zX700UeLczlFwiddAAAAAOAQTRcAAAAAOETTBQAAAAAOpcxM1759+0Ru2rSpyPn5+SJba0Xu2bNnwe3c3Fxx3y+//CKyvr979+4ip6WliVy7dm2R9Z5Is2fPFvnVV18VWc+MoWTS105369ZNZF1X5cuXF7lOnToF
t/W/h0OHDkV8bn2uChUqiHzq1KmIawGKIi8vT+R45khycnJEzsrKivj4Y8eOiaxnZqJlwC/9viSe19FnnnlG5EceeUTkL774QmT9vkabO3euyHovRv2+B+Gkf9dPmTIl0PMXngnTM1p6j9upU6eK/O6774p87ty5QNdWHHjnBAAAAAAO0XQBAAAAgEM0XQAAAADgUNLOdB0+fFjkcePGiayvx69Zs6bI9evXF3nQoEEFt8uWLSvua9KkScQcr9OnT4v8zjvviDxp0qRAnw+pYcOGDSLrfbiiXe+vr7mfP39+we3Vq1dHfKw2ffp0ke+77z6R9fkAF+KZcYk2w6X5na9hjhHRnDx5UuQtW7aI/P7770c8vlq1agW3a9WqJe5r2bKlyPo9UZkyZUTWc/DRLFmyROSzZ8+KzExXyaD3uvruu+8iPl6/LlavXl3k7OxskQcMGFBwW9fsAw88IPKoUaNETk9PF/n8+fM3PLcxxkybNu1Gy04YfosAAAAAgEM0XQAAAADgEE0XAAAAADiUNDNdFy9eFPmll14Sec6cOSJXqVJF5M8//1zk22+/XeQLFy7Eu8TA/PHHH4leAhJAz0V17do14uP13ll6LivafEBhzZs3F1nPv3Ts2DHi8S1atBBZ7wmzcuXKmNcCXDNs2LAiH1uvXj1fj7906ZLIBw4cKPJzA8YYs3//fpH16+jGjRtFjjYXOGTIkILbI0aM8LWWI0eOiDx06FBfx6Nk2rlzp8g//fSTr+P1DJf+N+GHfs+js6b36Tp+/HiRn7u48EkXAAAAADhE0wUAAAAADtF0AQAAAIBDSTPTtXv3bpH1DJe2fv16kRs2bBjx8dGuDQVc03NUek8XbeLEiSL369cv5udq2rSpyMuXLxe5QoUKMZ/LmOv3tmPPFgRB78vyzTffxHys3lcomnnz5onMzAvitW7dOpF//PHHBK3EmEqVKok8cuRIkceMGePrfG+88YbIOTk5RVsYkpreC8vvXJTeh6s46T5gwYIFCVpJ7PikCwAAAAAcoukCAAAAAIeS5vLCZ599VmTP80TWX8Ua7XLCRLp8+bLI+mti9c+GcNq3b5/I+fn5Ius60V9pHY/09PTAzvVPdA3rnwWIhb7kKdpXavft27fgdoMGDXw919ixY309HtDOnz8vcl5ensj6dVDnZs2aiay3EalYsWK8S4x5LdHW5vdyRJQMeuxhwIABCVpJauKTLgAAAABwiKYLAAAAAByi6QIAAAAAhxI207V582aR9VcFW2tF7tq1q/M1BUXPJeifRV87jXA4cOCAyG3atBH58OHDIkebX0kmepbhzJkzIqfSz4LE6dWrl8jRZgEbN24scm5ubpGf2+8cYo8ePUTu0KFDkZ8b4TBhwgSR3377bZGjvQ6OGDFC5CBnuPQWJNHWpt+HLF682NnakDxmzpwp8owZMyI+vlatWiK3aNFC5DJlygSyrqL417/+JfKwYcNEHj9+fMTjEzGLzjslAAAAAHCIpgsAAAAAHKLpAgAAAACHEjbTdfbsWZHPnTsncu3atUV+7LHHnK8pVhcvXhR50qRJER/fpUsXkV977bXA14TEy8rKEvm3335L0EqCt3btWpFXrlyZoJUglezYsUPkTZs2iaznTKLNw/qxfv16kQ8ePBjxuTQ9HwBs2LAhruP3798vst6bsVSpUnGd34+FCxeKrN9zIZz0a2q019jevXuL3KRJk8DXFBS/vz8SMYvOJ10AAAAA4BBNFwAAAAA4RNMFAAAAAA4lbKYrmnLlyomcyD0j9AzXBx98IPIrr7wickZGhsgjR44UuWzZssEtDilr9uzZiV7CDR06dEhkvb+S1rBhQ5FLl07alxYUo61bt4rscs5R7yW3YMECkY8ePRrx+BUrVoicmZkZzMIQGp06dRJ56dKlvo5fs2aNyH379hU5LS3thsfm5+eLvHr1apH1PpGafq68vDyR69atG/F4APHjky4AAAAAcIimCwAAAAAcoukC
AAAAAIeSdvDiySefTNhz79u3T+Rx48aJPGXKFJGfeuopkXNzc90sDKFSs2bNRC+hgJ7hevjhh0XWexzVqlVLZL1vV4UKFQJcHUqqnj17xvzY0aNHixxt/0StRo0avh6Pkmfs2LG+Ht+6dWuR9XuDSDNc2q5du0Tu3r27r7U8//zzIjdu3NjX8UhdhfeHGzNmTAJXEh/9/Qrz5s0TefLkyRGP179PunXrFszCfOCTLgAAAABwiKYLAAAAAByi6QIAAAAAhxI20+V5XsQ8Y8YMkbOzs52tZe7cuSI/99xzIh87dkxkfW30hAkT3CwMKUXX8OXLlyM+vk2bNiJfunQp8DVdo/cwGjx4sMjTp0+PeHyjRo1E1vvNVK9eveiLA26gX79+N7xPX7+vZxVuuinyf1PU+xY1aNDA3+JQ4rRq1Urkn3/+OeLjV61aJfJ7770nst73a9GiRQW3dT3r3yfR6nvUqFEiM8NVctWuXbvgtt43Vn8nQTLTM1y9e/f2dbzeI7dMmTJxr8kvPukCAAAAAIdougAAAADAIZouAAAAAHAoYTNd1tqIee/evSLr65P79+8vcqVKlUTW11pPnTq14PbatWvFfTt37hRZX9uv98PQM12AMdfvE/Ttt9+KfPTo0YjHd+7cWWT9b+KJJ54Q+c477xS58L8RPV925swZkaPtqzVx4kSR27VrJzIzXIhFtDnHaHOPmzZtErnwa3e0mRctKytL5JycnIiPB7Q6deqIHG2uSnvrrbciZj/n1vfrvRMTsQcREKSFCxeK/PTTT/s6vkqVKiK/+OKLca8pXnzSBQAAAAAO0XQBAAAAgEM0XQAAAADgkNXX3CsR74zHunXrRG7RooWv42+99VaR09PTRd66dWvM52rbtm3ErGcBkpyN/hAU4qzGd+zYIbKucT3j5Xcflkiinatjx44iDxo0SOSHHnqoyM9dDKhxf5zVeDSF9x0y5vr52OKs+b/++ktk/TsjyVDjsSu2+u7SpYvIS5cujfj4IOtbz93Wr19f5MWLF4uckZFR5OcqBtS3P4HV+KxZs0SOtk9X1apVRa5bt67Iy5YtE9nP6+ru3btF7tGjh8j6+xZOnDghcvny5UWuXLmyyF999ZXIeg7esX+scT7pAgAAAACHaLoAAAAAwCGaLgAAAABwKGEzXfraTL2nxKpVqyIer9et9zTSatSoUXBbz69kZ2dHPDbFcK20P8U2D3Dy5EmR9bXVev+3eK7/r127tsiPP/64yOPHjxe5XLlyRX6uBKDG/UnYTNeaNWtE1rOE+vdAPDV/zz33iDxixAiR9TxOqVKlivxcxYAaj12x1ffp06dF7tOnj8hLliwROZ6Zrr59+4rcvn17kfVreoqhvv0JrMZ/+OEHkfV7b71Hrkt+38frfbc+/vhjkTt06BDMwoLBTBcAAAAAFDeaLgAAAABwiKYLAAAAABxK2EyXdurUKZGjzbtEuxZ09OjRIg8cOLDgdrVq1Yq8zhTAtdL+JGzeRfvyyy9FHj58uMhbtmwROTMzU+SxY8cW3K5Xr564r0GDBkEsMVlQ4/4kTY0vX75cZD2nEs9M14ULF4p8bBKixmOXsPo+fvy4yF27dhV59erVIuv6XrFihciFZ8/1a3ZaWlqR15mEqG9/nNV4v379RJ45c6arp7qO35muBQsWiNypU6fA1xQgZroAAAAAoLjRdAEAAACAQ0lzeSECw8f2/lDjqYca9ydpazw3N1fkwYMHi6wvk9VfyV1Y48aNg1tY4lHjsUva+sYNUd/+OKvxQ4cOidyoUSOR8/PzXT31dZcXDh06VOSsrCyRMzIyRI7ncvRiwOWFAAAAAFDcaLoAAAAAwCGaLgAAAABwiJmu8OFaaX+o8dRDjftDjaceajx21Hfqob79ocZTDzNdAAAAAFDcaLoAAAAAwCGaLgAAAABwiKYLAAAAAByi6QIAAAAAh2i6AAAAAMAhmi4AAAAAcIimCwAAAAAcoukCAAAAAIdougAA
AADAIZouAAAAAHDIep6X6DUAAAAAQGjxSRcAAAAAOETTBQAAAAAO0XQBAAAAgEM0XQAAAADgEE0XAAAAADhE0wUAAAAADv0/19iupGJFddUAAAAASUVORK5CYII=\n", 118 | "text/plain": [ 119 | "
" 120 | ] 121 | }, 122 | "metadata": { 123 | "needs_background": "light" 124 | }, 125 | "output_type": "display_data" 126 | }, 127 | { 128 | "data": { 129 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA10AAACLCAYAAACa9PPwAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAVZklEQVR4nO3de7DV87/H8fenC8nuplNK+qHblNQhRdEcl8qMoYvogjQuUWPINEShmwklyj3RT1J7ROkI81M6SjmTLkLmILl2dNLWXbdd0ff8Udbs99vea6/vXuvzXd+1PB8zzaxX37W+3096r9V6W+u9Py4IAgEAAAAA+FEp2wsAAAAAgHxG0wUAAAAAHtF0AQAAAIBHNF0AAAAA4BFNFwAAAAB4RNMFAAAAAB7RdAEAAACARznVdDnn9pb4dcQ5d6BEvj6N8650zg0o5z4dnHOfO+f2O+dWO+fOquj1gLJQ48hn1DfyHTWOfEeNV1xONV1BEBT8+UtE/ldEupf4vUJf13XOnSAiC0TkRRGpIyJzReQ/nXNVfF0Tf0/UOPIZ9Y18R40j31HjFZdTTVd5nHOVnXOjnHM/OOe2OecKnXO1jx070Tk3xzm3wzm3yzm3yjlXxzn3hIh0EJHpx7r0J0o5dTcRKQ6C4PkgCA6KyBMiUkNEOkf2hwOEGkd+o76R76hx5DtqvGx51XSJyHARuUyO/gWcKiKHRWTKsWODRKSKiDQSkX8TkTtE5FAQBHeLyBoRGXSsS7+7lPO2FpF1f4YgCI6IyP8c+30gStQ48hn1jXxHjSPfUeNlyLema7CIjAiCYHMQBMUiMk5E+jnnnBz9S68nIk2DIPg9CII1QRDsS/G8BSKy2/zebjnaYQNRosaRz6hv5DtqHPmOGi9DznwPsjzH/jIbi8i/nHNBiUOVRKSuiPxTRBqIyDznXIGIvCoio4Ig+COF0+8VkZrm92qKyJ60Fw6kiBpHPqO+ke+oceQ7ajy5vPmkKwiCQET+T0QuDYKgdolf1YIg2BYEwcEgCEYHQdBSRP5DRPqISP8/H17O6b8UkX//MzjnKonIWcd+H4gENY58Rn0j31HjyHfUeHJ503Qd84KITHDONRYRcc7Vd851P3a7q3PuzGN/Sb+JyO8i8mdnXSQiTZKcd7GInOCcG+KcO15EhonIPhH5b09/DqAs1DjyGfWNfEeNI99R42XIt6brMRH5LxFZ4pzbIyIrRKTdsWON5OiPmtwjRwfv/iUibxw7NkVEBjrndjrnHrMnDYLggIj0FJEhIrJLjnblvYIg+N3jnwUoDTWOfEZ9I99R48h31HgZ3NFPAgEAAAAAPuTbJ10AAAAAECs0XQAAAADgEU0XAAAAAHhE0wUAAAAAHtF0AQAAAIBHVco5zo82zD0u2wvIMdR47qHGw6HGcw81njrqO/dQ3+FQ47mn1Brnky4AAAAA8IimCwAAAAA8oukCAAAAAI9ougAAAADAI5ouAAAAAPCIpgsAAAAAPKLpAgAAAACPaLoAAAAAwCOaLgAAAADwiKYLAAAAADyi6QIAAAAAj2i6AAAAAMAjmi4AAAAA8IimCwAAAAA8oukCAAAAAI+qZHsBvhw6dEjl8ePHJ24//PDD6tjFF1+s8vz581WuVatWZhcHVMC3336rctOmTVXeunVr4vaiRYvUMVvTvXv3TnqtTp06qdy8efOU1wkACO/IkSOJ22vXrlXHHnroIZXvv/9+latUCfd2rm3btioff/zxoR4PIDw+6QIAAAAAj2i6AAAAAMAjFwRBsuNJD8bZ9u3bVW7QoEGZ9y35kb6IyJtvv
qlyr169Mrcw/1y2F5BjYlPjBw8eVPm2225T2dbliSeeWObjf/vtt7TWUr16dZULCgpUfuutt1Tu2LFjWtcLiRoPJzY1HlZxcbHKP/30U+L2O++8o47de++9KleqpP+f4uDBg1U+44wzVL799ttVts+viFHjqcvZ+rYOHz6cuF2tWjWv1xo5cqTKJUcwIkB9hxObGt+1a5fKderUUdk5/Vdrewx7PJl77rlH5X/84x8qn3322SpfeOGFFb6WB6VenE+6AAAAAMAjmi4AAAAA8IimCwAAAAA8ypuZrv3796t8zTXXqLx48eIyH8tM199abGp81KhRKtutDcrTrl27xO3GjRurY+Vte2CfA7Nnz056f3u+L7/8UuVTTjkl6ePTRI2HE5sat/744w+VZ82apfLYsWNV3rRpU5nnSmd2QERk2LBhKk+aNCnU4zOMGk9dbOrbblWzceNGle2PdbdzhVHOdNmZx9q1a6u8atUqlZs0aZLJy1Pf4cSmxu1MV926dbO0kr+aO3euyldddZXKEc94MdMFAAAAAFGj6QIAAAAAj2i6AAAAAMCjnJ3pmjdvnspz5sxRecGCBSmfy86z2P1cunTponLbtm1Vbt68ecrXigDflQ4nazW+ZcsWle2eE7/++qvKp59+usrvvfeeyg0bNkzctvMAxx13XNK12NeBqVOnqjx06FCV7SzOoEGDVH7mmWdUzvB8AjUeTmxfx+fPn69y3759K3wuO3tr95IL6/fff0/r8WmixlMXm/qeMGGCyg888IDKNWvWVHnhwoUqd+jQIXHb7lH01FNPZWKJKfvmm29UbtasWSZPT32HE5sat/uJTp8+XeUxY8aovHPnTu9rKsuOHTtULm+2PcOY6QIAAACAqNF0AQAAAIBHNF0AAAAA4FHOznRVrlxZZbvnRBh2pqu8c9kZrkWLFqls90iKGN+VDidrNf7jjz+q3LRpU5XtnhLZ3D9uypQpKt93330q2/mXTz/9VGU7r5Ymajyc2LyO2znGknvLifx1jjGZmTNnqty/f3+V7QzMvffem/K5RZjpyiGR1bfdh2vy5Mkqjxs3Lun9rdGjR6tcch5m3bp16tiQIUNUXrNmjcrlvJcLzT437fXSRH2HE5vX8PLY9zV2FrDkfPktt9yijtmf1bB169a01mLnert3757W+UJipgsAAAAAokbTBQAAAAAe0XQBAAAAgEc5M9M1YMAAlQsLC1VOZ6arfv36Ktu9NL777rtQ57N7GEWM70qHk7Ua37Bhg8otW7ZU+e6771Z50qRJ3teUqlatWqls93Sxe8w89thjmbw8NR5ObPaiGzFihMqzZs1S2c4x2vnZpUuXJm6ffPLJSR9rX4c3bdqk8gUXXKByUVGRyu3bt1d55cqVEiFqPHWR1bed2XrooYfSOl+3bt1Ufu211xK369Spk/Sx9t+HJ598UuXx48erbOv33XffTXr+6tWrq7xs2bLEbTvvVQHUdzixeS9u2ddZO6dlX+OvvfbaxO3Zs2erY3v27FHZ7gFm972ze4ZZdk/djz/+WOUM7x9qMdMFAAAAAFGj6QIAAAAAj2i6AAAAAMCj2M502XmXnj17Jj0eZqbrwQcfVNn+7P4aNWqovHjxYpXvuuuupOe3+yn16NEj5bVlAN+VDidrNW5r+p133lH57bffVvnKK6/0vqZU2f1l7PyA/c7/J598ksnLU+PhZK3GP/zwQ5W7dOmist0jseQeLiIiM2bMUPm6667L2NqeeOIJle28md2nyx5/9NFHM7aWUlDjqctYfR8+fFjlp59+WuWRI0eqnOn57S+++CJxu3Xr1knvW1xcrLLd465Ro0Yq79+/X+XevXurvGTJkqTXKzmnO3HixKT3TQH1HU5sZrpszU+dOlVl+/7Y7tNV8r2Afa9dHrsX49y5c0M9fvfu3SoXFBSEenxIzHQBAAAAQNRougAAAADAI5ouAAAAAPAoNjNdu3btUvmss85S2e6hYmcB7EyX3d/l5
ptvTty23zmtWrVq0rXZ74G2adNG5V9++UXlE044QeUXX3xR5T59+qhcuXLlpNcPie9KhxNZje/cuVPlCy+8UGVb42vWrFG5SZMmfhZWAXbPF7vnETNdsZK1eYA777xTZfv9f/vvT9++fVUuuW+Rb506dVJ59erVKtvn6/Lly30uhxpPXcbq277mduzYMVOnLtV5552n8oIFCxK37f6hmbZjxw6V69Wrl/T+l19+eeK23V+pdu3aYS9PfYcTm5ku+169bt26Se9v3//afbzC+Prrr1W2eynaOUdr/vz5Ktu5+gxjpgsAAAAAokbTBQAAAAAe0XQBAAAAgEdVsr2AP9mf/W/nW8pz1VVXqfzKK6+oXL169QqtS0SkVq1aKk+ZMkVlu3fAvn37VL7hhhtUvuyyy1Q+6aSTKrw25I7XX39d5fXr16t86623qhynGS4gFQcOHFD5gw8+CPX4IUOGZHI5odxxxx0qDxw4MEsrQbbYfboyzc5plZzhKu24T2H/rO+9917i9vfff6+OnXvuuRlZE+LP7idqdevWTeV+/fpl7NqtWrVS2e5dOm/evKSP/+yzz1T2PNNVKj7pAgAAAACPaLoAAAAAwCOaLgAAAADwKDYzXWFdeumlKr/00ksqpzPDVZ6uXbuqfMkll6gcdo4Bfw92v4o6deqoPHz48CiXA2ScnfXYsGFD0vv36NFD5fPPPz/ja8qULVu2qLx3716VCwoKolwOPOjcubPKzmV2Oym7R2eUM1yWfe4Bpdm8ebPKY8eOVdnuS2t/5oHP10W7D2R5M11xwCddAAAAAOARTRcAAAAAeBTbrxceOXIk6fHFixdHtJK/CoJAZfvj7stb+7hx41R+6qmnMrMw5JQOHTqo3KxZsyytBMiMVatWhbr/xIkTVa5WrVoml5NRP/zwg8obN25UuXXr1lEuBznglltuUdnWOxB306ZNU/mnn35SuW/fvirbH+sOjU+6AAAAAMAjmi4AAAAA8IimCwAAAAA8is1M1/Tp01WuVCm+/aD9kfDLly9X2a7d5jFjxvhZGGLl0KFDKh88eDBLKwGisW/fPpXt/KvVokULn8tJi53NjfO/SYgHO6f7wgsvqOyzhoqKilQ+77zzVN62bZvKYf89GjZsWOL2OeecE3J1yBXFxcUqz5w5M+n9Bw8e7HM5eYd/RQAAAADAI5ouAAAAAPCIpgsAAAAAPIrNTFdhYWG2l5Cwf/9+lTdt2qTyXXfdFep8DRs2VLly5coVWxhyyrJly1T++uuvVW7cuHGUy8mo119/PenxqlWrRrQSxMnKlStVds5laSXps/M3ufxnQTRsjWRyhmvRokUqr1ixIulx+74lrBo1aqg8aNCgxG3mG/OXncP9+eefs7SS/MQzBwAAAAA8oukCAAAAAI9ougAAAADAo9jMdMXJ5MmTVR43blyox9u9Z95++22Va9WqVbGFAVliv9c9a9aspPefOnWqz+UAkatZs2bSjNxn51nCzvHt3btX5Y0bN4Z6/KhRoxK37f6f27dvV9nOnmfakiVLVG7ZsqXX6wG+tWnTJttL4JMuAAAAAPCJpgsAAAAAPKLpAgAAAACPmOkSkQEDBqi8du3atM7XoUMHlZs3b57W+YCo2RmuSZMmqbxjxw6Vr7jiCpXbtm3rZ2GAJ88880zS43ZOMZf32UPp7B6c5dWE9dVXX6ncpEmTtNcUle7du6vcrFmzLK0EueSzzz5T+eKLL87OQkph5xDt+5Rs4JMuAAAAAPCIpgsAAAAAPKLpAgAAAACPYjPTZffHOHLkSNL7r1u3Lunxnj17qmxnVJJdq1Kl9HrRV199Na3HIz/YmY84789mnwMTJ05U+fnnn1f5tNNOU9nOPqT7HEJusrN/y5YtU7moqEjl4cOHJ318lDZv3qxygwYNVL766qujXA6y4Prrr1c57ExXnNWvX1/l888/X+XZs2erXFBQ4H1NyH1TpkxR+aabblK5du3a3q796aefJj3eo0cPl
atVq+ZtLaninREAAAAAeETTBQAAAAAe0XQBAAAAgEfOzlIZSQ9m0pw5c1S+4YYbkt4/k3NY6Z7rwQcfVHnMmDEVXksGuGxePAdFVuPt2rVT2dbZ8uXLVa5evXrGrm3nVeyM1scff6zy0qVLk55v/fr1Krdo0SKN1YVGjYcTWY1bH374ocpdunRR2da43efI515YI0aMUNnOkw0ZMkTl5557zttaSkGNpy5j9b1lyxaVe/XqpfKaNWsydSnvTj31VJUXLlyocqtWraJcjkV9hxPZa/iBAwdUDjvbd/LJJ6v8yCOPqHzjjTcmbn/++efqmH3+WZMnT1b5o48+UtnuiWvfU/mcLytFqTXOJ10AAAAA4BFNFwAAAAB4RNMFAAAAAB7FZqZr9+7dKrdp00blX375RWWfM12NGjVS2e5nMW3aNJVr1KihctWqVSu8lgzgu9LhZG2my36f+aKLLlI5k/Ms77//vsq//vpr0vvb72UPHDhQ5fHjx6sccc1T4+FkbaZr586dKo8dO1blZ599VuXCwkKV+/fvn7G12DlE+3yzNbx27VqV7XPCM2o8dd7qe9u2bSq3b99e5WT7f/p23HHHqWz3gbTzlC1btvS9pDCo73Aiew23PYHdS9Huy1WeKlX0dsAlX0ft8+vgwYOhzm298cYbKmd5b0VmugAAAAAgajRdAAAAAOARTRcAAAAAeBSbmS7r22+/VXnevHkq272xMjnT9eabb6ps9+qIOb4rHU5kNb569WqV7Xel7Z4TPtnnS7169VR+9NFHVS65t0YMUOPhZO113Nq1a5fKdl5269atKo8aNUrlYcOGlXluOx9g9/zq169f0mtNmDBB5XvuuafMa0WAGk9dZPVta6Z79+4q+9zHq3fv3knztdde6+3aHlDf4WTtNXzJkiUqd+vWLUsr+asZM2aoPGDAAJXT6QsygJkuAAAAAIgaTRcAAAAAeETTBQAAAAAexXamqzxffPGFyk8//bTKM2fOVLnkTMrQoUPVMfvf4LTTTlPZ7n8Rc3xXOpys1fiePXtUtt+VtjNg6Rg5cqTKHTt2VNnOJsQcNR5ObF/H9+7dq/Kdd96p8oIFC1Q+88wzE7dHjBihjt12220ql7cXnd177vHHH1f5pJNOSvp4z6jx1GWtvrds2aLyihUrVO7Tp0/Sx9sZFPvepKQWLVqobPcHzTHUdzhZq3H7/tjONTZs2DCytbz88ssq29dw52JVVsx0AQAAAEDUaLoAAAAAwKOc/XohyhSrz1dzADWee6jxcHKmxouLi1UuKipSefTo0YnbhYWF6pj9ke+W/apX48aNVc7yjxe2qPHU5Ux9I4H6Dic2NW5fo+3XDTdu3Kiy/cps586dE7evuOIKdaxr165Jr21fo2P2dUKLrxcCAAAAQNRougAAAADAI5ouAAAAAPCIma78E+svucYQNZ57qPFwqPHcQ42njvrOPdR3ONR47mGmCwAAAACiRtMFAAAAAB7RdAEAAACARzRdAAAAAOARTRcAAAAAeETTBQAAAAAe0XQBAAAAgEc0XQAAAADgEU0XAAAAAHhE0wUAAAAAHtF0AQAAAIBHLgiCbK8BAAAAAPIWn3QBAAAAgEc0XQAAAADgEU0XAAAAAHhE0wUAAAAAHtF0AQAAAIBHNF0AAAAA4NH/AxBPOLFE6FjeAAAAAElFTkSuQmCC\n", 130 | "text/plain": [ 131 | "
" 132 | ] 133 | }, 134 | "metadata": { 135 | "needs_background": "light" 136 | }, 137 | "output_type": "display_data" 138 | }, 139 | { 140 | "data": { 141 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA10AAACLCAYAAACa9PPwAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAANMElEQVR4nO3de6hVZZ8H8OfxlnnpdJkwNCKjQilDTKx4/xik4EhlLxn+M/OCSdJUSFNZWeCNnG5YBFHQZCMq0wyC9U/UgNGr0OBUViqEUuREXpgIsaPZxbSz5o+cM/tZ5T7n6H72PnudzweC9W2ttf2JP4/7x7OfvWJRFAEAAIA8hrS6AAAAgCozdAEAAGRk6AIAAMjI0AUAAJCRoQsAACAjQxcAAEBGg27oijEOjTEejTFe0upaIAc9TtXpcapOj1Nlg7W/B/zQdfIP5f/+644x/lST/76/r1cUxa9FUYwpimLvadbzLzHGL07W8pfTeQ2opcepOj1O1elxqkx/N8aAH7pO/qGMKYpiTAhhbwhhds3/e718fYxxWOaStocQ7gkh7Mz86zBI6HGqTo9TdXqcKtPfjTHgh67exBj/Kca4Icb47zHG70MIf4kx3hBj/CDG2BVj/J8Y44sxxuEnrx8WYyxijJeezP968vx/xBi/jzH+V4xx4ql+vaIoXiqK4q8hhGPN+P2BHqfq9DhVp8epMv3dN20/dJ10ewjh30IIHSGEDSGEEyGEfwwh/E0I4U8hhFkhhH+oc//fhRCWhhDOD79N8CtzFgunQY9TdXqcqtPjVJn+7kVVhq7/LIriraIououi+Kkoim1FUXxYFMWJoij+O4Twagjhb+vcv7Eoio+LojgeQng9hDC1KVVD3+lxqk6PU3V6nCrT373I/ZnLZtlXG2KMk0IIz4cQrg0hjAq//T4/rHP/NzXHP4YQxjS6QDhDepyq0+NUnR6nyvR3L6qy0lWU8j+HED4LIVxeFMU5IYRlIYTY9KqgcfQ4VafHqTo9TpXp715UZegqGxtCOBxC+CHGODnU/wxpv8QYR8QYR4bfGmd4jHFkjHFQNxEtocepOj1O1elxqkx/l1R16FoUQpgXQvg+/DZpb2jga/81hPBTCGFGCGHNyeM/NfD1oS/0OFWnx6k6PU6V6e+SWBTl1UAAAAAapaorXQAAAAOCoQsAACAjQxcAAEBGhi4AAICMDF0AAAAZDevlvK82bD8D/jkFA4webz96vH/0ePvR432nv9uP/u4fPd5+/rDHrXQBAABkZOgCAADIyNAFAACQkaELAAAgI0MXAABARoYuAACAjAxdAAAAGRm6AAAAMjJ0AQAAZGToAgAAyMjQBQAAkJGhCwAAICNDFwAAQEaGLgAAgIwMXQAAABkNa3UBwOlZuXJlkpctW9ZzPGPGjOTcpk2bktzR0ZGvMAAAEla6AAAAMjJ0AQAAZBSLoqh3vu7JwWr16tVJvueee5Lc3d2d5M8//zzJV155ZZ7CfhNzvngFtU2Pd3V1JfmKK65I8qFDh3qOY0zbYPv27UmeMmVKg6trKj3eP23T42W//vprkvfs2dNz/MADDyTn3nnnnabU1CR6vO/apr/L77cWLlyY5Ndff73neO/evcm5c845J19hzae/+6dteryZ1qxZk+QFCxYkedWqVUletGhR9ppq/GGPW+kCAADIyNAFAACQkaELAAAgI18Z3wfvvfdekh966KEkDxlSf3Yt76+B0zFq1Kgk33bbbUleu3ZtE6uB/I4dO5bkSZMm9RxffPHFybmjR48mecyYMfkKg9Nw4sSJJL/99ttJPnLkS
M/x1q1bk3OzZs3KVxi0gfK/BytWrEhy+b32kiVLknz11VcnubOzs3HF9ZGVLgAAgIwMXQAAABkZugAAADKyp6sPvvjiiyT//PPPLaqEwWzEiBFJnjhxYosqgdbbv39/kg8fPpxke7oYaIYPH57k6dOnJ7n22VwHDhxoSk0wUJWfebthw4Yk9/Z3ZPz48UmeOnVqYwo7A1a6AAAAMjJ0AQAAZGToAgAAyMierj+wa9euJJefBVA2bdq0JG/atCnJo0ePbkhdDG7lvYTbt29vUSXQekVRtLoEOCOPPPJIkt98882e488++6zZ5cCAsmfPniTPnz+/X/e/8cYbSR43btwZ13SmrHQBAABkZOgCAADIyNAFAACQkT1dIYQvv/wyyTfffHOSDx06VPf+Z555JskdHR2NKQxqHD9+PMnlvYf1fPDBB0m+5JJLkqxnaTcxxiQfO3asRZXA6Zk8efIpz73yyitJXrlyZZI9h46q6erqSvJdd93Vr/vnzp2b5KuuuuqMa2o0K10AAAAZGboAAAAyMnQBAABkZE9XCOG1115L8r59++peP2fOnCTPnDmz4TVB2dixY5P84IMPJvnee+895b3lcxdccEGSyz0N7WbHjh1Jvuyyy1pUCZye2mfPlfcobtmyJcm33nprM0qCpuns7Ezyxx9/XPf6c889N8lLly5N8vDhwxtTWANZ6QIAAMjI0AUAAJCRoQsAACCjQbmn68cff0zyqlWrkjxkSDqLlve/lJ+XAa1w9913J7neni5oR+Wfxeedd17P8XfffZec2717d1NqglzKz56r5Tl0VN22bduSXO/vQwi/38M1EJ/LVWalCwAAICNDFwAAQEaGLgAAgIwGzZ6urq6unuM///nP/bp3xYoVSZ40aVIjSoKG6u7u7jku74WBdjRy5Mgkz549u+d4/fr1zS4HgAZ5+umnk1z7nLoQfr+na+7cuUleuHBhnsIy8s4MAAAgI0MXAABARoYuAACAjAbNnq7333+/53jr1q11ry1/bvTOO+/MURI0VO0+rt6ebwEA0EzLly/vOV67dm1yrvy+5YYbbkjymjVrkjxsWPuNMFa6AAAAMjJ0AQAAZNR+a3N9tG3btiTPmzfvlNfWfg1xCCGsXr06yeWvLQZgYDl48GCrS4AzUvuV2T4iThV8/fXXSa79SOH+/fvr3vvoo48medSoUQ2rq1WsdAEAAGRk6AIAAMjI0AUAAJBRZfZ0dXV1Jfn666/v872XX355kkePHt2QmgBojnXr1iX5hRdeaFElcHrs46Jqyt+RsG/fvlNeO23atCTPnDkzS02tZKULAAAgI0MXAABARoYuAACAjCqzp+v5559P8pAhfZ8nFy9e3OhyoOm6u7t7jnvr/3fffTfJc+bMyVITNNKsWbN6jtevX9/CSqC5pkyZ0uoSoFcbNmxI8rPPPpvkevsWN2/enOSxY8c2rrABwkoXAABARoYuAACAjAxdAAAAGbXtnq4DBw4keePGjX2+d/78+Um+8MILG1ITtFLtPq7envdSfnbGihUrkjxu3LiG1QWNMnHixFOe++WXX5J8+PDhJHd0dGSpCZrhoosuanUJ8DtHjhxJ8nPPPZfk2r3mIYQwdOjQnuPHH388OVfFPVxlVroAAAAyMnQBAABkZOgCAADIqG33dE2fPj3JBw8erHt9Z2dnz/FLL72UpSZopSVLlvQcP/nkk/26t7zHq/a1YKCo3Q9QVhRFko8fP567HIBB5dChQ0m+6aabkrxz58669z/11FM9xw8//HDjCmsTVroAAAAyMnQBAABkZOgCAADIqG33dH377bdJrn1G0R9ZvHhxz/GIESOy1AStdM0117S6BMiqdi/v1KlTk3M7duxI8osvvpjkJ554Il9hkNmJEydaXQKEvXv3Jrm3PVxlt99+eyPLaTtWugAAADIydAEAAGRk6AIAAMiobfZ0lb/Pv7u7u1/32+9C1d1xxx09x5MnT07O7dq1q+69S5cuTfJ99
92X5PPPP/8Mq4PGmjNnTpK/+uqrJC9btqyZ5UBWmzdvTnLtz3tolq6urn5dX/45PWHChEaW03asdAEAAGRk6AIAAMjI0AUAAJDRgN3TdeDAgSRv3LgxyeXncp111llJXr58eZJHjx7dwOpgYJsxY0aSd+/eXff63p5zBwNdjDHJQ4cObVEl0Dfl9yXXXnttz/Enn3zS7HKgVwsWLOjX9YsWLUryyJEjG1lO2/FOCwAAICNDFwAAQEaGLgAAgIwG7J6uo0ePJrm8x6vs0ksvTfLixYsbXRK0jfvvvz/J69ata1El0Bzl58d89NFHSb7uuuuaWQ70qrzvcOzYsae89q233kqy53TRDN98802Sy+/Ny15++eUk+7mbstIFAACQkaELAAAgI0MXAABARgN2Txdw+sp7HGuf/xKCZ8DQ/l599dUkl5//ctlllzWzHDhjtc9X3LJlS3Kut700kMPOnTuTfPjw4brXn3322UkuPz9xsLPSBQAAkJGhCwAAIKMB+/HCCRMmJPmWW25JcvnrU4H/19HRkeQPP/ywRZVAHrNnz07yp59+muQRI0Y0sxw4Y4899ljP8Y4dO5Jz8+bNa3Y5EDo7O5M8fvz4JP/www9JvvHGG7PX1M6sdAEAAGRk6AIAAMjI0AUAAJBRLIqi3vm6JxmQfD9n/+jx9qPH+0ePtx893nf6u/3o7/7R4+3nD3vcShcAAEBGhi4AAICMDF0AAAAZGboAAAAyMnQBAABkZOgCAADIyNAFAACQkaELAAAgI0MXAABARoYuAACAjAxdAAAAGcWiKFpdAwAAQGVZ6QIAAMjI0AUAAJCRoQsAACAjQxcAAEBGhi4AAICMDF0AAAAZ/S/AUXhtUtNabwAAAABJRU5ErkJggg==\n", 142 | "text/plain": [ 143 | "
" 144 | ] 145 | }, 146 | "metadata": { 147 | "needs_background": "light" 148 | }, 149 | "output_type": "display_data" 150 | }, 151 | { 152 | "data": { 153 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA10AAACLCAYAAACa9PPwAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAMeklEQVR4nO3dbaiV5ZoH8Os20yC1U82YUCLOVMZARuEk6FiQOR+iE6UcTLIp8jAT2/oQpyCDJiSo4dQhGlTmw0xML4fGSYVUgmC0F7KSXX5JIrGCRhqb2ExiRi+mz3xot9j3Ou7l3rnvZ7me/fvBhvvf/bT3ZV5t1sWz7vWkqqoCAACAMiZ0uwAAAIAmM3QBAAAUZOgCAAAoyNAFAABQkKELAACgIEMXAABAQYYuAACAgnpq6EopHRnydTyl9O2QfNspfN93U0orT3LNv6eU9g/+3Ft/6c+CTvQ4Taa/aTo9TtPp8V+up4auqqqm/PwVEf8dEb8e8s/+WPjH74mIv4+IvYV/DuOYHqfJ9DdNp8dpOj3+y/XU0HUyKaUzUkoPp5Q+TSkNpJT+mFL61eDe2Sml/0gp/V9K6VBKaXdK6dyU0h8i4q8j4l8Hp/Q/nOh7V1X1z1VVvRYRP9T4R4KMHqfJ9DdNp8dpOj0+vEYNXRHxQET8bUT8TURcFBFHI+Kpwb3fRsTEiLgwIv4sIu6JiB+qqvpdRPRHxG8Hp/Tf1V41jJwep8n0N02nx2k6PT6Mpg1d/xARD1ZV9T9VVX0XEWsjYnlKKcVPf+l/HhF/WVXVj1VV9VdV9U03i4VfQI/TZPqbptPjNJ0eH8bEbhcwVgb/MmdGxCsppWrI1oSIOD8i/i0iZkTEppTSlIh4LiIerqrqWO3Fwi+gx2ky/U3T6XGaTo931pg7XVVVVRHxeURcV1XVr4Z8nVVV1UBVVd9XVfWPVVVdFhHXRMRvIuLnTz6phvu+cLrQ4zSZ/qbp9DhNp8c7a8zQNehfIuKfUkozIyJSStNTSr8eXF+fUvqrlNKEiDgcET9GxM+T9f9GxF90+sYppUkppbMiIkXEmSmlswYneqiTHqfJ9DdNp8dpOj0+jKYNXb+PiP+KiJ0ppa8j4u2IuGpw78KIeDkivo6fPmrylYj4z8G9pyLi71JKX6WUfj/M934zIr4d/H7PDa7nl/hDQAd6nCbT3zSdHqfp9Pgw0k93AgEAACihaXe6AAAATiuGLgAAgIIMXQAAAAUZugAAAAoydAEAABQ08ST7Ptqw9/TM8wpOE3q89+jx0dHjvUePj5z+7j36e3T0eO85YY+70wUAAFCQoQsAAKAgQxcAAEBBhi4AAICCDF0AAAAFGboAAAAKMnQBAAAUZOgCAAAoyNAFAABQkKELAACgIEMXAABAQYYuAACAggxdAAAABRm6AAAACjJ0AQAAFDSx2wWcjgYGBrI8ffr0LL/00ktZXrZsWfGaAMa7DRs2tNarV6/O9pYuXZrlzZs311ITjNTll1+e5Q8//DDL3377bWs9adKkWmoC6uNOFwAAQEGGLgAAgIK8vfAE9u3bl+UJE/LZ9KKLLqqzHBhze/bsyfK8efOyvGXLlizfdNNNWW7/fwLqsGPHjmH32nu2/ff4nDlzitQEI5VS6pjfeeed1vraa6+tpSYo6brrrsvytGnTWutnn3022zvnnHNqqambvHICAAAoyNAFAABQkKELAACgIGe6TmD37t1Znjp1apbnz59fZzlwyoZ+FHHEyR9z0P7x299//32WnemiG9rPbXXSfv7LmS66bcaMGVlu/8j4xYsXt9Y//vhjLTVBn
bZt29ZaP/fcc9nevffeW3c5tfPKCQAAoCBDFwAAQEGGLgAAgIKc6YqIgwcPZvmRRx7J8n333VdnOTDmPvjggyx/9tlnHa+/5557sjxxol8V9Ja+vr5ulwCZxx57LMsPPPBAlnft2tVaHzlyJNubMmVKucKgkPvvvz/Lr7/+emv95Zdf1lxN97nTBQAAUJChCwAAoCBDFwAAQEEOasSfnm/55ptvsrxy5co6y4FT1v6MlwcffHBU//6qVauynFI65ZoAxrN58+Zl+YknnsjywoULh91bu3ZtucKgJuP9tYQ7XQAAAAUZugAAAAoydAEAABSUqqrqtN9xsykWL16c5QMHDmR57969WZ40aVLxmk7B+H7D7Og1ssfbzynOnj274/Xtz+H64YcfxrymMaTHR6cxPb5hw4bWevXq1R2vXb9+fZZ77Lldenzkera/N2/enOXly5e31lOnTs32vvrqq1pqqon+Hp2e7fH2Z3HNmDGjtW4/33Xs2LFaaqrJCXvcnS4AAICCDF0AAAAFGboAAAAKGpfP6Tp06FCWX3vttSzPnTs3y6f5GS74E1u2bBnV9bfeemuhSmDsnOwcF/SSp556qtslQFHTp0/P8tBzXO1nunbv3p3l+fPnlyusS9zpAgAAKMjQBQAAUJChCwAAoKBxeaZrz549HfdnzpxZUyVQxo4dOzruT548OcuPP/54yXIAaLNy5cosv/vuu12qBOqxZs2a1rr9dcfQvYiInTt31lJTndzpAgAAKMjQBQAAUJChCwAAoKBxeaarv7+/4/7atWtrqgTGzqefftpav/LKKx2vnTJlSpYvvPDCIjUBcGL79+/vdgnQNe3P6RoP3OkCAAAoyNAFAABQkKELAACgoHFzpmvoeZcnn3wy21u0aFGW586dW0tNMJbef//9EV/78MMPF6wEgJPZtGnTsHuHDx/O8vbt27N84403FqkJSlqyZElr3f6crkOHDmX56NGjWT7zzDPLFVYTd7oAAAAKMnQBAAAUNG7eXrhjx47WemBgINu74oorsjxx4rj5z0KD7Nq1a9i98847L8t33nln4Wqgu/r6+rpdAnT05ptvZnnNmjWt9Ysvvpjtbdu2LcveXkgvWrBgQWt91VVXZXt79uzJ8hdffJHlmTNnliusJu50AQAAFGToAgAAKMjQBQAAUNC4Obz03nvvtdYppWxv5cqVdZcDp+zjjz/O8rp164a9tv1M17Rp04rUBMDIzJo1K8vXX399a71x48Zs79VXX62lJihp6Me+T506NdurqirLb7/9dpaXL19errCauNMFAABQkKELAACgIEMXAABAQY0903XkyJEsb9++vbVufy7X1VdfXUtNMJYOHTqU5ePHjw977bJly0qXAwAwIu2vxd94440s7927N8vOdAEAANCRoQsAAKAgQxcAAEBBjT3TtWnTpiwfPHiwtV6xYkXd5cCYe+GFF4bda38u19133126HACAEVmyZEmWn3766S5VUh93ugAAAAoydAEAABRk6AIAACiosWe6Pvnkk2H3zj///BorgbFx+PDhLK9bt27Yay+++OIsz5o1q0hNUKelS5e21lu2bOl47YYNG7Lc19dXpCYATl1KKcvtz+lqAne6AAAACjJ0AQAAFGToAgAAKKixZ7qef/75YfduueWWGiuBsdH+/ubjx48Pe+1tt91WuhwAxlD7OcShvvvuuyy3n/GdNm1akZqgLlVVZXnr1q1dqqQcd7oAAAAKMnQBAAAUZOgCAAAoqDFnuvbv35/lzz//vEuVQBkDAwMd9y+44ILWetWqVaXLgeL27duX5ZM9mwt62bnnnjvsXvvv//7+/iwvXry4SE1Ql/bndLXnJnCnCwAAoCBDFwAAQEGGLgAAgIIac6Zr8+bNWT527FiWFy1a1FpfeumltdQEY+nll1/uuH/ZZZe11pMnTy5dDhT30EMPdbsEqM2KFSta6507d3axEihv3rx5WW5/9uiECc27L9S8PxEAAMBpx
NAFAABQkKELAACgoJ4903X06NEsb9y4seP1d9xxR2vdxPeJ0jzt5xL37t3b8fqzzz67tT7jjDOK1AQlncpzuT766KMsz5kzZ0xqAmDsTZ8+Pcvtr809pwsAAIBRMXQBAAAUZOgCAAAoqGfPdLW/93PGjBlZvvLKK7N8++23F68JxlL7+5mvueaaLPf392fZGRaabv369a11X19fFyuBsbdw4cLW+oYbbsj23nrrrSx73ihNc9ddd2X5mWeeyfL+/fuzfMkllxSvaay50wUAAFCQoQsAAKCgVFVVp/2Om5yWmvcZm2X1TI9//fXXWX700UezvGDBgtb65ptvrqWmLtHjo9MzPU6LHh85/d179PfojIseP3DgQJZnz56d5a1bt2a5/S24p5kT9rg7XQAAAAUZugAAAAoydAEAABTkTFfzeK/06Ojx3qPHR0eP9x49PnL6u/fo79HR473HmS4AAIC6GboAAAAKMnQBAAAUZOgCAAAoyNAFAABQkKELAACgIEMXAABAQYYuAACAggxdAAAABRm6AAAACjJ0AQAAFJSqqup2DQAAAI3lThcAAEBBhi4AAICCDF0AAAAFGboAAAAKMnQBAAAUZOgCAAAo6P8B8D+87+5Y0vAAAAAASUVORK5CYII=\n", 154 | "text/plain": [ 155 | "
" 156 | ] 157 | }, 158 | "metadata": { 159 | "needs_background": "light" 160 | }, 161 | "output_type": "display_data" 162 | } 163 | ], 164 | "source": [ 165 | "dataset_train,dataset_test,x_train_origin,y_train_origin,x_test_origin,y_test_origin = buildDataSet()\n", 166 | "print(\"Checking shapes for class 0 (train) : \",dataset_train[0].shape)\n", 167 | "print(\"Checking shapes for class 0 (test) : \",dataset_test[0].shape)\n", 168 | "print(\"Checking first samples\")\n", 169 | "for i in range(2):\n", 170 | " DrawPics(dataset_train[i],5,template='Train {}',classnumber=i)\n", 171 | " DrawPics(dataset_test[i],5,template='Test {}',classnumber=i)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 5, 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "data": { 181 | "text/plain": [ 182 | "60000" 183 | ] 184 | }, 185 | "execution_count": 5, 186 | "metadata": {}, 187 | "output_type": "execute_result" 188 | } 189 | ], 190 | "source": [ 191 | "len(x_train_origin)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 6, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "def build_network(input_shape, embeddingsize):\n", 201 | " '''\n", 202 | " Define the neural network to learn image similarity\n", 203 | " Input : \n", 204 | " input_shape : shape of input images\n", 205 | " embeddingsize : vectorsize used to encode our picture \n", 206 | " '''\n", 207 | " # Convolutional Neural Network\n", 208 | " network = Sequential()\n", 209 | " network.add(Conv2D(128, (7,7), activation='relu',\n", 210 | " input_shape=input_shape,\n", 211 | " kernel_initializer='he_uniform',\n", 212 | " kernel_regularizer=l2(2e-4)))\n", 213 | " network.add(MaxPooling2D())\n", 214 | " network.add(Conv2D(128, (3,3), activation='relu', kernel_initializer='he_uniform',\n", 215 | " kernel_regularizer=l2(2e-4)))\n", 216 | " network.add(MaxPooling2D())\n", 217 | " network.add(Conv2D(256, (3,3), activation='relu', 
kernel_initializer='he_uniform',\n", 218 | " kernel_regularizer=l2(2e-4)))\n", 219 | " network.add(Flatten())\n", 220 | " network.add(Dense(4096, activation='relu',\n", 221 | " kernel_regularizer=l2(1e-3),\n", 222 | " kernel_initializer='he_uniform'))\n", 223 | " \n", 224 | " \n", 225 | " network.add(Dense(embeddingsize, activation=None,\n", 226 | " kernel_regularizer=l2(1e-3),\n", 227 | " kernel_initializer='he_uniform'))\n", 228 | " \n", 229 | " #Force the encoding to live on the d-dimensional hypersphere\n", 230 | " network.add(Lambda(lambda x: K.l2_normalize(x,axis=-1)))\n", 231 | " \n", 232 | " return network\n", 233 | "\n", 234 | "class TripletLossLayer(Layer):\n", 235 | " def __init__(self, alpha, **kwargs):\n", 236 | " self.alpha = alpha\n", 237 | " super(TripletLossLayer, self).__init__(**kwargs)\n", 238 | " \n", 239 | " def triplet_loss(self, inputs):\n", 240 | " anchor, positive, negative = inputs\n", 241 | " p_dist = K.sum(K.square(anchor-positive), axis=-1)\n", 242 | " n_dist = K.sum(K.square(anchor-negative), axis=-1)\n", 243 | " return K.sum(K.maximum(p_dist - n_dist + self.alpha, 0), axis=0)\n", 244 | " \n", 245 | " def call(self, inputs):\n", 246 | " loss = self.triplet_loss(inputs)\n", 247 | " self.add_loss(loss)\n", 248 | " return loss" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 7, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "def build_model(input_shape, network, margin=0.2):\n", 258 | " '''\n", 259 | " Define the Keras Model for training \n", 260 | " Input : \n", 261 | " input_shape : shape of input images\n", 262 | " network : Neural network to train outputting embeddings\n", 263 | " margin : minimal distance between Anchor-Positive and Anchor-Negative for the loss function (alpha)\n", 264 | " \n", 265 | " '''\n", 266 | " # Define the tensors for the three input images\n", 267 | " anchor_input = Input(input_shape, name=\"anchor_input\")\n", 268 | " positive_input = Input(input_shape, 
name=\"positive_input\")\n", 269 | " negative_input = Input(input_shape, name=\"negative_input\") \n", 270 | " \n", 271 | " # Generate the encodings (feature vectors) for the three images\n", 272 | " encoded_a = network(anchor_input)\n", 273 | " encoded_p = network(positive_input)\n", 274 | " encoded_n = network(negative_input)\n", 275 | " \n", 276 | " #TripletLoss Layer\n", 277 | " loss_layer = TripletLossLayer(alpha=margin,name='triplet_loss_layer')([encoded_a,encoded_p,encoded_n])\n", 278 | " \n", 279 | " # Connect the inputs with the outputs\n", 280 | " network_train = Model(inputs=[anchor_input,positive_input,negative_input],outputs=loss_layer)\n", 281 | " \n", 282 | " # return the model\n", 283 | " return network_train" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 9, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "Model: \"model_2\"\n", 296 | "__________________________________________________________________________________________________\n", 297 | "Layer (type) Output Shape Param # Connected to \n", 298 | "==================================================================================================\n", 299 | "anchor_input (InputLayer) (None, 28, 28, 1) 0 \n", 300 | "__________________________________________________________________________________________________\n", 301 | "positive_input (InputLayer) (None, 28, 28, 1) 0 \n", 302 | "__________________________________________________________________________________________________\n", 303 | "negative_input (InputLayer) (None, 28, 28, 1) 0 \n", 304 | "__________________________________________________________________________________________________\n", 305 | "sequential_2 (Sequential) (None, 10) 4688522 anchor_input[0][0] \n", 306 | " positive_input[0][0] \n", 307 | " negative_input[0][0] \n", 308 | 
"__________________________________________________________________________________________________\n", 309 | "triplet_loss_layer (TripletLoss [(None, 10), (None, 0 sequential_2[1][0] \n", 310 | " sequential_2[2][0] \n", 311 | " sequential_2[3][0] \n", 312 | "==================================================================================================\n", 313 | "Total params: 4,688,522\n", 314 | "Trainable params: 4,688,522\n", 315 | "Non-trainable params: 0\n", 316 | "__________________________________________________________________________________________________\n", 317 | "['loss']\n" 318 | ] 319 | }, 320 | { 321 | "name": "stderr", 322 | "output_type": "stream", 323 | "text": [ 324 | "/opt/conda/lib/python3.7/site-packages/keras/engine/training_utils.py:819: UserWarning: Output triplet_loss_layer missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to triplet_loss_layer.\n", 325 | " 'be expecting any data to be passed to {0}.'.format(name))\n" 326 | ] 327 | } 328 | ], 329 | "source": [ 330 | "network = build_network(input_shape,embeddingsize=10)\n", 331 | "network_train = build_model(input_shape,network)\n", 332 | "optimizer = Adam(lr = 0.00006)\n", 333 | "network_train.compile(loss=None,optimizer=optimizer)\n", 334 | "network_train.summary()\n", 335 | "plot_model(network_train,show_shapes=True, show_layer_names=True, to_file='02 model.png')\n", 336 | "print(network_train.metrics_names)\n", 337 | "n_iteration=0\n", 338 | "#network_train.load_weights('mnist-160k_weights.h5')" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 10, 344 | "metadata": {}, 345 | "outputs": [ 346 | { 347 | "name": "stdout", 348 | "output_type": "stream", 349 | "text": [ 350 | "[[-0.13133799 -0.09094287 -0.4855275 -0.39615053 -0.23171382 -0.13620552\n", 351 | " 0.4054796 -0.35857293 -0.44204152 0.14551745]]\n" 352 | ] 353 | } 354 | ], 355 | "source": [ 356 | "featured_img 
= network.predict(np.ones((1,img_rows,img_cols,1)))\n", 357 | "print(featured_img)" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 11, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "def get_batch_random(batch_size,s=\"train\"):\n", 367 | " \"\"\"\n", 368 | " Create batch of APN triplets with a completely random strategy\n", 369 | " \n", 370 | " Arguments:\n", 371 | " batch_size -- integer \n", 372 | "\n", 373 | " Returns:\n", 374 | " triplets -- list containing 3 tensors A,P,N of shape (batch_size,w,h,c)\n", 375 | " \"\"\"\n", 376 | " if s == 'train':\n", 377 | " X = dataset_train\n", 378 | " else:\n", 379 | " X = dataset_test\n", 380 | "\n", 381 | " m, w, h,c = X[0].shape\n", 382 | " \n", 383 | " \n", 384 | " # initialize result\n", 385 | " triplets=[np.zeros((batch_size,h, w,c)) for i in range(3)]\n", 386 | " \n", 387 | " for i in range(batch_size):\n", 388 | " #Pick one random class for anchor\n", 389 | " anchor_class = np.random.randint(0, nb_classes)\n", 390 | " nb_sample_available_for_class_AP = X[anchor_class].shape[0]\n", 391 | " \n", 392 | " #Pick two different random pics for this class => A and P\n", 393 | " [idx_A,idx_P] = np.random.choice(nb_sample_available_for_class_AP,size=2,replace=False)\n", 394 | " \n", 395 | " #Pick another class for N, different from anchor_class\n", 396 | " negative_class = (anchor_class + np.random.randint(1,nb_classes)) % nb_classes\n", 397 | " nb_sample_available_for_class_N = X[negative_class].shape[0]\n", 398 | " \n", 399 | " #Pick a random pic for this negative class => N\n", 400 | " idx_N = np.random.randint(0, nb_sample_available_for_class_N)\n", 401 | "\n", 402 | " triplets[0][i,:,:,:] = X[anchor_class][idx_A,:,:,:]\n", 403 | " triplets[1][i,:,:,:] = X[anchor_class][idx_P,:,:,:]\n", 404 | " triplets[2][i,:,:,:] = X[negative_class][idx_N,:,:,:]\n", 405 | "\n", 406 | " return triplets\n", 407 | "\n", 408 | "def drawTriplets(tripletbatch, nbmax=None):\n", 409 | " 
\"\"\"Display the three images of each triplet in the batch\n", 410 | " \"\"\"\n", 411 | " labels = [\"Anchor\", \"Positive\", \"Negative\"]\n", 412 | "\n", 413 | " if (nbmax==None):\n", 414 | " nbrows = tripletbatch[0].shape[0]\n", 415 | " else:\n", 416 | " nbrows = min(nbmax,tripletbatch[0].shape[0])\n", 417 | " \n", 418 | " for row in range(nbrows):\n", 419 | " fig=plt.figure(figsize=(16,2))\n", 420 | " \n", 421 | " for i in range(3):\n", 422 | " subplot = fig.add_subplot(1,3,i+1)\n", 423 | " axis(\"off\")\n", 424 | " plt.imshow(tripletbatch[i][row,:,:,0],vmin=0, vmax=1,cmap='Greys')\n", 425 | " subplot.title.set_text(labels[i])" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": 12, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [ 434 | "def compute_dist(a,b):\n", 435 | " return np.sum(np.square(a-b))\n", 436 | "\n", 437 | "def get_batch_hard(draw_batch_size,hard_batchs_size,norm_batchs_size,network,s=\"train\"):\n", 438 | " \"\"\"\n", 439 | " Create batch of APN \"hard\" triplets\n", 440 | " \n", 441 | " Arguments:\n", 442 | " draw_batch_size -- integer : number of initial randomly taken samples \n", 443 | " hard_batchs_size -- integer : select the number of hardest samples to keep\n", 444 | " norm_batchs_size -- integer : number of random samples to add\n", 445 | "\n", 446 | " Returns:\n", 447 | " triplets -- list containing 3 tensors A,P,N of shape (hard_batchs_size+norm_batchs_size,w,h,c)\n", 448 | " \"\"\"\n", 449 | " if s == 'train':\n", 450 | " X = dataset_train\n", 451 | " else:\n", 452 | " X = dataset_test\n", 453 | "\n", 454 | " m, w, h,c = X[0].shape\n", 455 | " \n", 456 | " \n", 457 | " #Step 1 : pick a random batch to study\n", 458 | " studybatch = get_batch_random(draw_batch_size,s)\n", 459 | " \n", 460 | " #Step 2 : compute the loss with current network : d(A,P)-d(A,N). 
The alpha parameter is omitted here since we only want to order them\n", 461 | " studybatchloss = np.zeros((draw_batch_size))\n", 462 | " \n", 463 | " #Compute embeddings for anchors, positives and negatives\n", 464 | " A = network.predict(studybatch[0])\n", 465 | " P = network.predict(studybatch[1])\n", 466 | " N = network.predict(studybatch[2])\n", 467 | " \n", 468 | " #Compute d(A,P)-d(A,N)\n", 469 | " studybatchloss = np.sum(np.square(A-P),axis=1) - np.sum(np.square(A-N),axis=1)\n", 470 | " \n", 471 | " #Sort by distance (high distance first) and take the hard_batchs_size hardest triplets\n", 472 | " selection = np.argsort(studybatchloss)[::-1][:hard_batchs_size]\n", 473 | " \n", 474 | " #Draw other random samples from the batch\n", 475 | " selection2 = np.random.choice(np.delete(np.arange(draw_batch_size),selection),norm_batchs_size,replace=False)\n", 476 | " \n", 477 | " selection = np.append(selection,selection2)\n", 478 | " \n", 479 | " triplets = [studybatch[0][selection,:,:,:], studybatch[1][selection,:,:,:], studybatch[2][selection,:,:,:]]\n", 480 | " \n", 481 | " return triplets" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": 13, 487 | "metadata": {}, 488 | "outputs": [ 489 | { 490 | "name": "stdout", 491 | "output_type": "stream", 492 | "text": [ 493 | "Checking batch width, should be 3 : 3\n", 494 | "Shapes in the batch A:(2, 28, 28, 1) P:(2, 28, 28, 1) N:(2, 28, 28, 1)\n", 495 | "Shapes in the hardbatch A:(2, 28, 28, 1) P:(2, 28, 28, 1) N:(2, 28, 28, 1)\n" 496 | ] 497 | }, 498 | { 499 | "data": { 500 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAvEAAACLCAYAAADoHsZvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAR6UlEQVR4nO3dfZBU1ZnH8d+DvKmM8iIvosObyuKqCUWtwT9k1yVlFJfgCxJFjaMk1lqrriuik4pbkSVLdtFSQTBsFmphFUbDABriSxYVTEmiiA5x0UQTRFgSRRcDDMiIwpz9497pveeG6Z6e6Z7uM/P9VFF1f33v7T4zVJ155vTTZ8w5JwAAAADh6FLqAQAAAADID0U8AAAAEBiKeAAAACAwFPEAAABAYCjiAQAAgMBQxAMAAACB6dRFvJnNNLNlpR4HABSKmQ0xswNmdkyWaw6Y2Yj2HBcAlJqZPWdmVaUeR6EEV8Sb2UtmtsfMepR6LABQCGa23cwa4uL6IzNbYma9WvNczrn/cc71cs4diZ/7JTP7duqaXs65bYUYOwBkE89vH5nZ8YnHvm1mLxX5df9kodY5N8E595/FfN32FFQRb2bDJI2T5CRNKulgUrKtegFAC3zdOddL0hhJ50r6xxKPBwAKpauk20s9iI4mqCJe0vWSXpW0VFLm7RAzW2pmj5jZM2a238w2mtlpifNnmdnzZvbH+LfB7yaes7uZPRrf97aZ/UXivjPjVay98blJqddcaGbPmtmnkv66mF84gM7BOfcHSc9JOtvMBpvZmnju2mpmNzVdZ2ZfMbPXzaw+ntcejB8fZmbOzLqa2WxFCx8L4lX+BfE1zsxON7PzzGxXchHCzC43s/+Oj7uY2XfM7D0z+8TMVphZ3/b8fgDoEO6XNMPMeqdPmNmoRI32rpl9I3Gun5n9NJ7nNpnZP5vZhsT5eWa2Mz7/hpmNix+/WNJ3JV0Vz31vxo+/FL8L0COu7c5OPFf/+B3RAXGeaGa/iq/7pZl9qWjfnVYKsYhfHv+7yMwGJs5NlfRPkvpI2ipptiSZWYWkFyT9TNJgSadLejFx3yRJT0jqLWmNpKYfct0k/VTSWkkDJN0mabmZ/Vni3mvi16mQtEEA0EZmVinpEkmbJT0u6feK5q4rJf3AzL4aXzpP0jzn3AmSTpO0Iv1czrl7JL0s6da4hebW1PlXJX0qaXzi4Wsk1cTHfy/pMkl/FY9hj6RHCvBlAuhcXpf0kqQZyQfjFpvnFc05AxTVcj80s7PiSx5RNEcNUrR4m+5n3yRptKS+8XPUmllP59zPJP1A0o/jue/LyZucc4ckrY5fr8k3JP3cOfexmY2R9B+S/lZSP0k/krSm3Fq5gynizex8SUMlrXDOvSHpPUU/bJqsds695pw7rKjIHx0/PlHSLufcA865z5xz+51zGxP3bXDOPRv3jz4mqek/+jxJvST9q3Puc+fcOklPy/8P/4lz7hfOuUbn3GeF/poBdCpPmdleRQsCP5f075LOl1Qdz12/krRY0jfj67+QdLqZneScOxAX5K3xuOJ5LV70uCR+TIp+gN3jnPt9/ENvpqQrzaxrK18LQOf1PUm3mVn/xGMTJW13zi1xzh12ztVJWqVonjlG0mRJ9zrnDjrnfi3J62d3zi1zzn0S3/uApB6Skout2dTIr+mSCxg3SfqRc26jc+5I3Ed/SFFtWDaCKeIV/fa11jm3O8418n8j25U4PqioAJekSkUFf3PS9/WMf0ANlrTTOdeYOL9D0imJvLPlwweArC5zzvV2zg11zv2dojnoj865/YlrknPQtySNlPRO/DbzxFa+bo2kK+IVpisk1TnndsTnhkp6Mn47ea+k30g6Imng0Z8KAI7OOfeWosXQ7yQeHippbNMcE88z1ypaee+vqJc+WWt5dZeZ3WlmvzGzffG9J0o6qYVDWifpWDMba2ZDFS3+PpkY152pcVUqmpfLRhCrKWZ2rKK3OY4
xs6aiu4ek3mb25ebvlBT9h0/Ncc3RfCCp0sy6JAr5IZJ+m7jGteJ5AaAlPpDU18wqEoX8EEl/kCTn3O8kTTWzLoqK75Vm1u8oz5N1nnLO/drMdkiaIH8lSormz2nOuV+07UsBAEnSvZLqJD0Q552KWlguTF8Yr8QflnSq/r/2qkycHyepWtJXJb3tnGs0sz2SLL4k19zXaGYrFNWIH0l6OjHX7pQ02zk3O/8vsf2EshJ/maLVnz9X9JvSaElnKur1vD7HvU9LGmRm/xB/kKHCzMa24DU3KurDutvMupnZBZK+rqh/HgCKyjm3U9IvJf2LmfWMP1T1LUXtgjKz68ysf7zIsDe+7chRnuojSbn2hK9R1P/+l5JqE4//m6TZ8SpV0we/Lm3t1wSgc3PObZX0Y0XzjRTVaCPN7JtxrdXNzM41szPjNufVkmaa2XFmNkp+zVehqMj/X0ldzex7kk5InP9I0rB4oaM5NZKuUrT6n1zAWCTp5niV3szseDP7m7jlsGyEUsRXSVoS73+8q+mfog+hXqss7yjEv1VdqKgA3yXpd2rBTjLOuc8Vfeh1gqTdkn4o6Xrn3Dtt/WIAoIWmShqmaFX+SUW9oc/H5y6W9LaZHVD0Iderm/lszjxF/aV7zOzhZl7ncUkXSFqXaFlsuneNpLVmtl/R7mAtWQQBgObMknS8lKnRvibpakXz3C5JcxR1W0jSrYpaZHYp+tzi44p60yXpvxTt5PVbRa2Gn8lvt2lakPjEzOqONpD4M5KfKmqTeS7x+OuK+uIXKPpA/1ZJN7Tuyy0ec46OEAAAAJQ3M5sjaZBzrsP81dW2CGUlHgAAAJ1IvIf8l+KWlq8oail8Mtd9nUUQH2wFAABAp1OhqIVmsKSPFX0g9iclHVEZoZ0GAAAACAztNAAAAEBgKOIBAACAwOTqiafXJjyW+xIAKcx14WGuA/LHXBeeZuc6VuIBAACAwFDEAwAAAIGhiAcAAAACQxEPAAAABIYiHgAAAAgMRTwAAAAQGIp4AAAAIDAU8QAAAEBgKOIBAACAwFDEAwAAAIGhiAcAAAAC07XUAwAAAABaa+HChV5+4YUXvDx37lwvV1ZWFn1M7YGVeAAAACAwFPEAAABAYCjiAQAAgMDQE98Czjkv79u3z8vjx4/38ogRI7y8cuXK4gwMAPKwevVqL1955ZVeXrJkiZerqqqKPiYAaI2HH344c3zPPfd45w4ePOjlU0891cvz5s0r3sDaESvxAAAAQGAo4gEAAIDA0E7TAl988YWX+/btm/X60047rZjDAYAW2bp1q5enTp3qZTPz8s033+zlKVOmePm4444r4OgAoOU2btzo5ZkzZ2aOGxoast57zjnnFGNIJcdKPAAAABAYingAAAAgMBTxAAAAQGDoiW+BJ554Iq/rBw8eXKSRAEDLLV++3MuHDx/Oev2hQ4e8PH/+fC9XV1cXZmAAkMPu3bu9PGHCBC/X19c3e++AAQO8fN111xVuYGWElXgAAAAgMBTxAAAAQGAo4gEAAIDAmHMu2/msJzuqV155xcvnn3++l3N8z/Tmm296uZ33J7XclwBI6TBz3f79+zPHw4YN887t3bs3673pua13795efu+997zcp0+fVoywYJjrgPwFM9d98MEHXh4yZEiL7033wC9durQQQyqVZuc6VuIBAACAwFDEAwAAAIGhiAcAAAACwz7xR7FmzRov5+qBX7JkiZdHjhxZ8DEBQEts2rQpc7xnz56s16Z7TLds2eLlnj17erlbt25tHB0AHF16vqqsrGzxvXPmzPHyjBkzCjKmcsdKPAAAABAYingAAAAgMBTxAAAAQGDoidef7gt///33Z70+3Ud6zTXXeJm+UQClUl1dnTk2y76V+uTJk71cUVFRlDEBQC6LFy/2cq75K2ngwIGFHk4QWIkHAAAAAkMRDwAAAASGIh4AAAAITKfsiU/v+37HHXd4ubGxMev9l156qZfpgQdQKps3b/byG2+8kTn
O1VN67LHHFmVMAJCvdE98LlOmTMkcpz+b2FmwEg8AAAAEhiIeAAAACAxFPAAAABAYS/eHp2Q9GZIjR45kjmfNmuWd+/73v5/13l69enm5vr6+cAMrvJZvrAqgSbBz3cKFC718yy23ZI7TPfEnnXSSl7dv3+7lwHrkmeuA/JXNXLdr1y4vn3LKKV7O9Zmet956K3M8atSowg2s/DT7jWAlHgAAAAgMRTwAAAAQGIp4AAAAIDCdZp/4urq6zHGuHvi06dOnF3o4ANAqhw8f9vKKFSu8nO1zTjNmzPByYD3wADqQF198Ma/rL7/8ci+PHDmy2Ws///xzL9fW1ub1Wsk96CWpe/fued3fXliJBwAAAAJDEQ8AAAAEhiIeAAAACEyH3Sd+27ZtXh43blzm+MMPP/TOpb8H06ZN8/LixYu9nGvv0hIr68EBZSqYue7AgQNePvHEE72cnM/Sc1V6Xhw6dGiBR9eumOuA/JXNXFdVVeXlRx991MvpPvQtW7Z4efjw4ZnjRYsWeeeSfy9Dkrp0aduadUNDQ9axFRn7xAMAAAAdBUU8AAAAEJgOs8Vketu16upqL6dbaJKSb8lI0uzZs71c5u0zADqR9FyXj27duhVwJADQeuvXr/dyuuVl7NixXk5vKXnbbbdljhcuXJj1udpax6W3qLz22mvb9HyFwko8AAAAEBiKeAAAACAwFPEAAABAYILtiW9sbPTy9OnTvbxq1SovZ9t2be3atV4eNGhQIYYIAAX3zDPPtPjaUaNGeblfv36FHg4AtIuNGzd6efny5S2+94QTTvByTU2Nl++8804vv/vuu15+7LHHvExPPAAAAIBWoYgHAAAAAkMRDwAAAAQm2J74ZI+7JC1YsCDr9ck++Kuvvto7N2LEiMINDADaUXouTOY+ffp453r06NEuYwKAttq8ebOXJ0yY4OX6+vpm7500aZKXFy9e7OXu3bt7+dChQ1nHku6pLxesxAMAAACBoYgHAAAAAkMRDwAAAAQmmJ74999/38t33313XvdfccUVmeNly5Z557p04XcZAGFI71ec/rsXSXfddVexhwMARdHQ0JA1ZzNnzhwv9+zZ08uXXHKJl7dv3571+dJ1Y7mgegUAAAACQxEPAAAABIYiHgAAAAhMMD3x69at8/KqVauyXt+/f38vz5o1K3NMDzyAUOzevdvLGzZsyHr9GWeckTm++OKLizImAGirxsbGrDmfWm38+PFefuqpp7xcW1vr5bq6uqzPN3/+fC+n95UvF1SzAAAAQGAo4gEAAIDAUMQDAAAAgTHnXLbzWU8W02uvvebldL/TwYMHvdy1q9/e/+qrr3p5zJgxBRxdWWt+02gAzSnZXJfLpk2bvHzeeedlvf6qq67KHNfU1BRlTGWCuQ7IX9nMdQ8++KCX03/XItvfwMglXdvmeq6JEyd6eeXKlV5O15jtrNnBsxIPAAAABIYiHgAAAAgMRTwAAAAQmLLZJ76hocHLN9xwg5fTPfBpF1xwgZc7UQ88gA5s/fr1Xs7xOaac5wGgHEyePNnL6Z74Yrrpppu8fO+993q5xD3wLcZKPAAAABAYingAAAAgMBTxAAAAQGBK1vSTqwf+nXfeyXr/ueee6+XVq1cXZFwAUE7S+xvn2u+4rq4uc/zyyy9750aPHu3lioqKNo4OAFpn6NChXl66dKmXb7zxxoK91rRp07w8d+5cL/fo0aNgr9WeWIkHAAAAAkMRDwAAAASGIh4AAAAIjOXYU7hoGw7v2LHDy8OHD8/r/n379nmZ3s6M7A2zAI6mbDdXX7BggZdvv/32rNcn5/STTz7ZO5fsl5ekgQMHtnF0JcVcB+SvbOe6xsZGL3/88cdeXrZsmZcfeuihzPFZZ53lnbvvvvu8fPbZZ3s5lH3gY83OdazEAwAAAIGhiAcAAAACQxEPAAAABKZkPfH19fVeHjNmjJe3bdvm5WeffdbLF110kZdz7Z3cifCNAPJXtn2ihw4d8nJVVZWXa2t
rvVxZWZk5XrRokXfuwgsvLPDoSoq5Dshf2c51aBY98QAAAEBHQREPAAAABKZk7TQoGt5iBvLHXBce5jogf8x14aGdBgAAAOgoKOIBAACAwFDEAwAAAIGhiAcAAAACQxEPAAAABIYiHgAAAAgMRTwAAAAQGIp4AAAAIDAU8QAAAEBgKOIBAACAwFDEAwAAAIGhiAcAAAACQxEPAAAABIYiHgAAAAgMRTwAAAAQGHPOlXoMAAAAAPLASjwAAAAQGIp4AAAAIDAU8QAAAEBgKOIBAACAwFDEAwAAAIGhiAcAAAAC839X22VX5TMFTgAAAABJRU5ErkJggg==\n", 501 | "text/plain": [ 502 | "
" 503 | ] 504 | }, 505 | "metadata": { 506 | "needs_background": "light" 507 | }, 508 | "output_type": "display_data" 509 | }, 510 | { 511 | "data": { 512 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAvEAAACLCAYAAADoHsZvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAATZUlEQVR4nO3de5DV5X3H8c8XkEsBQS6CCOKAUK2IGosaRasQQamhComAEeMtQ8fEiibxgo1ardGEpBkwRFNrE1sXDfFSCYiamiHj6ogapCogKChFkIvKyk25uE//+P12e76/sGf37J7ds8/u+zXjzPmc53d5Vmae/e6z3/NbCyEIAAAAQDzalHoCAAAAAApDEQ8AAABEhiIeAAAAiAxFPAAAABAZingAAAAgMhTxAAAAQGRadRFvZreb2cOlngcAFIuZHWFmO82sbZ5jdprZoKacFwCUmpktMrNvlnoexRJdEW9mi81sm5l1KPVcAKAYzOx9M/ssLa43m9mvzKxLfa4VQvjfEEKXEMIX6bUXm9lVmWO6hBDWFmPuAJBPur5tNrPOOe9dZWaLG/m+f7ZRG0I4L4TwUGPetylFVcSb2ZGSzpAUJI0v6WQy8u16AUAdfDWE0EXSlySNkPSPJZ4PABRLO0nXlnoSLU1URbykSyW9LOnXkqp/HWJmvzazOWa20Mx2mNkSMxucM36smf3ezD5JfxqckXPN9mb2H+l5y83sr3POOybdxapIx8Zn7nmfmT1tZrsknd2YXziA1iGEsEHSIknDzKyfmc1P1653zexbVceZ2clm9pqZbU/XtX9J3z/SzIKZtTOzu5RsfPw83eX/eXpMMLOjzOxUM9uUuwlhZhea2Rvp6zZmdpOZrTGzj81snpn1aMr/HwBahJmSvmdm3bMDZnZ0To22yswuyhnraWa/S9e5V83sn82sPGd8lpmtT8f/ZGZnpO+fK2mGpEnp2vc/6fuL098CdEhru2E51+qd/kb00DSfb2bL0uNeMrPhjfZ/p55iLOLL0v/GmlmfnLEpkv5J0iGS3pV0lySZWVdJ/y3pGUn9JB0l6fmc88ZLelRSd0nzJVV9kztI0u8kPSfpUEnXSCozs7/MOffi9D5dJZULABrIzAZIGifpdUmPSPpAydr1NUk/NLPR6aGzJM0KIRwsabCkedlrhRBukfSCpO+kLTTfyYy/LGmXpFE5b18saW76+h8kXSDpb9I5bJM0pwhfJoDW5TVJiyV9L/fNtMXm90rWnEOV1HK/MLNj00PmKFmj+irZvM32s78q6QRJPdJr/NbMOoYQnpH0Q0m/Sde+43NPCiHskfREer8qF0n6Ywhhi5l9SdK/S5omqaekX0qa39xauaMp4s1spKSBkuaFEP4kaY2SbzZVngghvBJC2K+kyD8hff98SZtCCD8NIXweQtgRQliSc155COHptH/0PyVV/UOfKqmLpHtCCHtDCH+QtED+H/ypEMKLIYTKEMLnxf6aAbQq/2VmFUo2BP4o6V8ljZR0Y7p2LZP0b5Kmpsfvk3SUmfUKIexMC/L6eETpupZueoxL35OSb2C3hBA+SL/p3S7pa2bWrp73AtB63SrpGjPrnfPe+ZLeDyH8KoSwP4SwVNLjStaZtpImSrothLA7hLBCkutnDyE8HEL4OD33p5I6SMrdbM1nrnxNl7uB8S1JvwwhLAkhfJH20e9RUhs2G9EU8Up++nouhPBRmufK/0S2Kef1biUFuCQNUFLw1yR7Xsf0G1Q/SetDCJU54+skHZ6T19d9+gCQ1wUhhO4hhIEhhKuVrEGfhBB25ByTuwZdKWmop
LfTXzOfX8/7zpU0Id1hmiBpaQhhXTo2UNKT6a+TKyStlPSFpD4HvhQAHFgI4S0lm6E35bw9UNIpVWtMus58Q8nOe28lvfS5tZaru8zsu2a20sw+Tc/tJqlXHaf0B0mdzOwUMxuoZPP3yZx5fTczrwFK1uVmI4rdFDPrpOTXHG3NrKro7iCpu5kdX/OZkpJ/8Cm1HHMgGyUNMLM2OYX8EZJW5xwT6nFdAKiLjZJ6mFnXnEL+CEkbJCmE8I6kKWbWRknx/ZiZ9TzAdfKuUyGEFWa2TtJ58jtRUrJ+XhFCeLFhXwoASJJuk7RU0k/TvF5JC8s52QPTnfj9kvrr/2uvATnjZ0i6UdJoSctDCJVmtk2SpYfUtvZVmtk8JTXiZkkLctba9ZLuCiHcVfiX2HRi2Ym/QMnuz18p+UnpBEnHKOn1vLSWcxdI6mtm09MPMnQ1s1PqcM8lSvqwbjCzg8zsLElfVdI/DwCNKoSwXtJLku42s47ph6quVNIuKDO7xMx6p5sMFelpXxzgUpsl1fZM+LlK+t/PlPTbnPfvl3RXuktV9cGvv6vv1wSgdQshvCvpN0rWGymp0Yaa2dS01jrIzEaY2TFpm/MTkm43s78ws6Pla76uSor8rZLamdmtkg7OGd8s6ch0o6MmcyVNUrL7n7uB8YCkv0936c3MOpvZ36Yth81GLEX8NyX9Kn3+8aaq/5R8CPUbyvMbhfSnqnOUFOCbJL2jOjxJJoSwV8mHXs+T9JGkX0i6NITwdkO/GACooymSjlSyK/+kkt7Q36dj50pabmY7lXzIdXINn82ZpaS/dJuZza7hPo9IOkvSH3JaFqvOnS/pOTPboeTpYHXZBAGAmtwhqbNUXaONkTRZyTq3SdKPlHRbSNJ3lLTIbFLyucVHlPSmS9KzSp7ktVpJq+Hn8u02VRsSH5vZ0gNNJP2M5C4lbTKLct5/TUlf/M+VfKD/XUmX1e/LbTwWAh0hAAAAaN7M7EeS+oYQWsxfXW2IWHbiAQAA0Iqkz5Afnra0nKykpfDJ2s5rLaL4YCsAAABana5KWmj6Sdqi5AOxT5V0Rs0I7TQAAABAZGinAQAAACJDEQ8AAABEpraeeHpt4mO1HwIgg7UuPqx1QOFY6+JT41rHTjwAAAAQGYp4AAAAIDIU8QAAAEBkKOIBAACAyFDEAwAAAJGhiAcAAAAiQxEPAAAARIYiHgAAAIgMRTwAAAAQGYp4AAAAIDIU8QAAAEBkKOIBAACAyFDEAwAAAJGhiAcAAAAiQxEPAAAARIYiHgAAAIgMRTwAAAAQGYp4AAAAIDIU8QAAAEBkKOIBAACAyFDEAwAAAJGhiAcAAAAi067UEyiWvXv3unzvvfe6/Oijj1a/fu2119xYr169XJ41a5bLkydPdrlNG372AdA0du/e7fKMGTNcnj17do3n3nHHHXmvPX36dJe7dOlS4OwAoGls377d5a1bt9b7WitWrHC5vLy83tc6kMcff7z69W233ebGpk6dWrT7UI0CAAAAkaGIBwAAACJDEQ8AAABExkII+cbzDpbSnj17XD7zzDNdzva9N8TatWtdHjhwYNGu3Qis1BMAItRs1rqdO3e6fOedd7qc22tZm+zaZeaXhwsvvNDlbO/moEGDXO7cuXOd790EWOuAwjXZWpddy5544om8x69cudLlbH26ceNGl8vKyuo8l+y1smthoQq53iWXXOLyz372M5d79OhR2+1qvDg78QAAAEBkKOIBAACAyFDEAwAAAJGJpif+vffec3nkyJEub9q0yeWZM2e6/PWvf7369Zo1a9zYpEmTXP7oo49cHjx4sMvZPqxhw4a53KlTJ5UQfaJA4Uq21u3atcvl8ePHu/zKK6+4/OGHH9b52lu2bHH59ttvd7m2ntJsL+dDDz1U53s3AdY6oHCNtta9+eabLmc/q5h9znttitnHXsqe+Kzvf//7Lt9zzz21nUJPPAAAANBSUMQDAAAAkaGIBwAAACLTbHvi9+3b5
/K5557r8uLFi12+4oorXJ4zZ47L7du3r/Fey5cvd3n48OF1naYkafbs2S5/+9vfLuj8IqNPFChcyda666+/3uVZs2a53LVrV5crKirqfa/KykqX33rrLZdHjBjhcrt27VxevXq1y4cffni951IErHVA4Rptrcv+HYnPP/+8QddrKT3xP/jBD1yeMWOGy/nq06rL1zTATjwAAAAQGYp4AAAAIDIU8QAAAEBk2tV+SGksW7bM5WwPfPfu3V2+7777XM72cubTq1evwiYHAEWS/fxPY9q7d6/LkydPdnn//v0uT5061eUS98ADaMaGDBnicva58U0p+9nGsWPHunzSSSe5/PHHH7v8/vvvF3S/7N8Hyl07Bw4c6Mbatm1b0LXzYSceAAAAiAxFPAAAABAZingAAAAgMs22J37JkiV5x3fu3Onye++953K2N6uYDjvsMJfHjx/faPcCgMby4x//2OVhw4a53KdPn6acDoCIPf/88y4PHTrU5W3bthV0vW7durl8ww03uHzzzTcXdL2WiJ14AAAAIDIU8QAAAEBkKOIBAACAyDTbnvgOHTrkHc8+z/iyyy5zeeHChS536dKlxmu98cYbBc3tlltucXnAgAEFnQ8AVa655hqX77//fpf37Nnjcvb5xf3793c539/IaN++vctjxozJO7fs8QBQk+zzzw8++GCXKyoq8p5/9913u3zllVe63LNnzwbMrmViJx4AAACIDEU8AAAAEBmKeAAAACAyFkLIN553sDHt3r3b5UGDBrm8devWgq43evToGseyzzatzeuvv+7y8OHDCzq/kVmpJwBEqGRrXdall17qcllZWd7jBw8e7PKECRNqPHbDhg0uz5071+VOnTq5/OKLL7p8/PHH551LE2OtAwrXZGvdq6++6vKpp55a0Pn33nuvy1dffXWD5xSpGtc6duIBAACAyFDEAwAAAJFptu00WS+88ILLo0aNcrmysrLJ5rJs2TKXjzvuuCa7dx3wK2agcM1mrXvnnXdczj52rby83OU1a9bUeK3s+m7ml4ds681TTz3lcrYNcdGiRTXeqwRY64DCNdlal13LTj75ZJe3b9+e9/wjjjjC5ex6lNv+17dvXzd20UUXuTxkyJD8k23eaKcBAAAAWgqKeAAAACAyFPEAAABAZKLpic/auHGjywsXLnR53rx5NZ6b7WHPPsYo219/1FFHubxq1ao6z7ME6BMFCtds17qs7ON39+7dW+9r7dixw+WhQ4e6fNZZZ7lMTzwQvZKtdTNnznT5pptuynt8bZ/paci9s4+r7NixY72v3QToiQcAAABaCop4AAAAIDIU8QAAAEBkou2Jb4hC/xTwlClTXH744YeLPqciok8UKFyLXOtqM336dJdnz57t8pgxY1x+5plnGn1OBWCtAwpXsrXu008/dfm0005zeeXKlS4Xsyc+e61rr73W5ezf48h9Bn0zQE88AAAA0FJQxAMAAACRoYgHAAAAItOu1BMohU8++aSg488555xGmgkANJ3PPvvM5WeffdblbM9pQ3pQASBXt27dXC4rK3N58+bNLk+YMMHl008/3eXy8vLq1/v27XNj2b/3k5X9+0DDhg1z+aqrrsp7fnPBTjwAAAAQGYp4AAAAIDIU8QAAAEBkWmVP/PLlyws6fty4cY00EwBoOgsWLHB59erVeY+fNm1aY04HQCt2wgkn5B3Prk89evRwecuWLdWvP/jgAzc2ceLEGo89kJdfftlleuIBAAAANAqKeAAAACAyFPEAAABAZCyEkG8872CsRo8e7fLixYvzHr9r1y6XO3bsWOwpFRMPdgYK1yLXuuxz4U855RSXs58P6tu3r8tLly51uU+fPkWcXYOx1gGFa5FrXdb8+fNdvuCCC1zO/g2MDh06uLxkyRKXjzvuuCLOrmA1rnXsxAMAAACRoYgHAAAAIkMRDwAAAESmVT4nHgBag+eee87l2v5Gxp133ulyM+uBBxCxFStWuHzIIYe4fNhhh
9X72uvWrXP5gQceKOj8PXv2uLxs2TKXS9wTXyN24gEAAIDIUMQDAAAAkaGIBwAAACLTap4Tv2PHjurX2d6m9evX5z2X58QDLV6LWety9e/f3+UPP/zQ5aOPPtrll156yeVu3bo1zsSKg7UOKFyTrXXZPvNjjz3W5YqKCpeffvppl3v27Olytqe+rKys+vW8efPyziVb62afE3/55Ze7/OCDD+a9XhPjOfEAAABAS0ERDwAAAESGIh4AAACITKt5Tvzbb79d/bq2HngAiFVuX+nGjRvdWLYP9Cc/+YnLzbwHHkBEDjroIJdPOukklx977DGXv/zlL7tcWx97XccOZOzYsS5PnDixoPObC3biAQAAgMhQxAMAAACRoYgHAAAAItNqeuLXrVtX73O/8pWvuLxgwQKXu3fvXu9rA0AxrV27tvp1bT2lJ554YpPMCUDr06aN3yeeNGmSy9me+IYYP368y/Pnz897/JAhQ1weN25c0ebSlNiJBwAAACJDEQ8AAABEptW004wYMaLe5+7fv9/ltm3bNnQ6ANAo5syZU/260MeuAUBjOeaYY1weM2aMyxdffLHL1113nctnn322yzfeeGP16+HDh7ux8vLyvHM544wz8k82EuzEAwAAAJGhiAcAAAAiQxEPAAAARKbV9MT369ev+vW0adPc2IYNG1yePHmyy9lHF3Xu3LnIswOA4hg5cmT161WrVrmxW2+91eXevXs3yZwAINsTv2jRorzHV1ZWunzaaae5nH1MZK5Ro0YVOLs4sRMPAAAARIYiHgAAAIgMRTwAAAAQGcv+We6MvINolngwNFA41rr4sNYBhWOti0+Nax078QAAAEBkKOIBAACAyFDEAwAAAJGhiAcAAAAiQxEPAAAARIYiHgAAAIgMRTwAAAAQmdqeEw8AAACgmWEnHgAAAIgMRTwAAAAQGYp4AAAAIDIU8QAAAEBkKOIBAACAyFDEAwAAAJH5P+6E+gVTZQ4XAAAAAElFTkSuQmCC\n", 513 | "text/plain": [ 514 | "
" 515 | ] 516 | }, 517 | "metadata": { 518 | "needs_background": "light" 519 | }, 520 | "output_type": "display_data" 521 | }, 522 | { 523 | "data": { 524 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAvEAAACLCAYAAADoHsZvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAATbUlEQVR4nO3de5CX1X3H8c+XiwtVLmIkqFURHNHGiUw0wQG1sVHYDYRaqgbwEiRKQVDrGAuoMbREgUEdrCBIO0oFFyVeqhEX8DLGwk40hNqKJiSyuAKyBImCN1Tg9I/n2e1zHtnf7u+6e3bfrxlnfp/fc/2x49nvnv0+Z805JwAAAADh6NDSNwAAAAAgOxTxAAAAQGAo4gEAAIDAUMQDAAAAgaGIBwAAAAJDEQ8AAAAEpl0X8WY2w8yWtfR9AEChmNkJZvaxmXXMsM/HZtavlPcFAC3NzKrM7EctfR+FElwRb2Yvm9kHZlbW0vcCAIVgZu+Y2Wdxcb3TzB4ysyNyOZdz7l3n3BHOuQPxuV82s6tT+xzhnKspxL0DQCbx+LbTzA5PvHe1mb1c5Ot+ZaLWOVfhnPuPYl63lIIq4s2sr6RzJTlJI1v0ZlIyzXoBQDP8wDl3hKRvSfq2pNta+H4AoFA6SbqhpW+irQmqiJd0paRfS1oiqeHXIWa2xMwWmNlKM/vIzF41s/6J7d8ws+fN7M/xT4O3JM55mJk9HB/3ppmdlTjutHgW68N428jUNRea2XNm9omk84v5wQG0D8657ZKqJJ1uZsea2TPx2PW2mV1Tv5+ZfcfM1pvZ3nhcuyd+v6+ZOTPrZGZ3KJr4mB/P8s+P93FmdrKZnW1mdclJCDP7OzP73/h1BzObZmabzWy3ma0ws16l/PcA0CbMlfQTM+uZ3mBmpyZqtE1mdmli21Fm9st4nPuNmf3czNYmtt9rZlvj7b81s3Pj98sl3SLph/HY9z/x+y/HvwUoi2u70xPnOjr+jWjvOI8ws9fj/arN7JtF+9fJUYhF/CPxf8PM7OuJbWMk/bOkIyW9LekOSTKzbpJekLRK0rGSTpb0YuK4kZIeldRT0jOS6r/JdZb0S0lrJPWWdJ2kR8xsQOLYsfF1uklaKwDIk5kdL+n7kv5b0nJJ2xSNXRdLutPMvhfveq+ke51z3SX1l7QifS7n3K2S/kvSlLiFZkpq+68lfSLpbxJvj5VUGb++XtJFkv46vocPJC0owMcE0L6sl/SypJ8k34xbbJ5XNOb0VlTL3W9m34h3WaBojOqjaPI23c/+G0kDJfWKz/ELM+vinFsl6U5Jj8Vj3xnJg5xzn0t6Mr5evUsl/co59ycz+5akByX9g6SjJD0g6ZnW1sodTBFvZudIOlHSCufcbyVtVvTNpt6TzrnXnHP7FRX5A+P3R0iqc87d7Zzb55z7yDn3auK4tc655+L+0aWS6r/QZ0s6QtJs59wXzrmXJD0r/wv+tHNunXPuoHNuX6E/M4B25T/N7ENFEwK/krRY0jmSpsZj1+uS/l3SFfH+X0o62cy+5pz7OC7Ic7Fc8bgWT3p8P35Pir6B3eqc2xZ/05sh6WIz65TjtQC0X7dLus7Mjk68N0LSO865h5xz+51zGyQ9oWic6Sjp7yX9zDn3qXPuLUleP7tzbplzbnd87N2SyiQlJ1szqZRf0yUnMK6R9IBz7lXn3IG4j/5zRbVhqxFMEa/op681zrn341wp/yeyusTrTxUV4JJ0vKKCvzHp47rE36COlbTVOXcwsb1W0nGJvLX5tw8AGV3knOvpnDvROXetojHoz865jxL7JMegH0s6RdLv4
18zj8jxupWSRsUzTKMkbXDO1cbbTpT0VPzr5A8l/U7SAUlfP/SpAODQnHMbFU2GTku8faKkQfVjTDzOXKZo5v1oRb30yVrLq7vM7CYz+52Z7YmP7SHpa828pZckdTWzQWZ2oqLJ36cS93VT6r6OVzQutxpBzKaYWVdFv+boaGb1RXeZpJ5mdkbjR0qKvuBjmtjnUN6TdLyZdUgU8idI+kNiH5fDeQGgOd6T1MvMuiUK+RMkbZck59wfJY0xsw6Kiu/HzeyoQ5wn4zjlnHvLzGolVcifiZKi8XO8c25dfh8FACRJP5O0QdLdcd6qqIXlwvSO8Uz8fkl/qf+vvY5PbD9X0lRJ35P0pnPuoJl9IMniXZoa+w6a2QpFNeJOSc8mxtqtku5wzt2R/UcsnVBm4i9SNPvzV4p+Uhoo6TRFvZ5XNnHss5L6mNk/xg8ydDOzQc245quK+rD+ycw6m9l3Jf1AUf88ABSVc26rpGpJs8ysS/xQ1Y8VtQvKzC43s6PjSYYP48MOHOJUOyU1tSZ8paL+9/Mk/SLx/iJJd8SzVPUPfv1trp8JQPvmnHtb0mOKxhspqtFOMbMr4lqrs5l928xOi9ucn5Q0w8z+wsxOlV/zdVNU5O+S1MnMbpfUPbF9p6S+8URHYyol/VDR7H9yAuPfJE2MZ+nNzA43s+Fxy2GrEUoR/yNJD8XrH9fV/6foIdTLlOE3CvFPVRcqKsDrJP1RzVhJxjn3haKHXiskvS/pfklXOud+n++HAYBmGiOpr6JZ+acU9YY+H28rl/SmmX2s6CHX0Y08m3Ovov7SD8zsXxu5znJJ35X0UqJlsf7YZyStMbOPFK0O1pxJEABozL9IOlxqqNGGShqtaJyrkzRHUbeFJE1R1CJTp+i5xeWKetMlabWilbz+oKjVcJ/8dpv6CYndZrbhUDcSPyP5iaI2marE++sV9cXPV/RA/9uSxuX2cYvHnKMjBAAAAK2bmc2R1Mc512b+6mo+QpmJBwAAQDsSryH/zbil5TuKWgqfauq49iKIB1sBAADQ7nRT1EJzrKQ/KXog9ukWvaNWhHYaAAAAIDC00wAAAACBoYgHAAAAAtNUTzy9NuGxpncBkMJYFx7GOiB7jHXhaXSsYyYeAAAACAxFPAAAABAYingAAAAgMBTxAAAAQGAo4gEAAIDAUMQDAAAAgaGIBwAAAAJDEQ8AAAAEhiIeAAAACAxFPAAAABAYingAAAAgMJ1a+gYAAACAUtm7d6+X16xZ4+VLLrnEy2bm5S5dunj5008/LeDdNR8z8QAAAEBgKOIBAACAwFDEAwAAAIGhJx4AAABtxoEDB7y8c+dOL1999dVeXr16tZc7dMg8x/3ll196+ZFHHvHyZZdd1qz7zBcz8QAAAEBgKOIBAACAwNBO0wzV1dV5HT948OAC3QkAtJz0WLh06VIvL1q0yMvl5eVerqqqKs6NAWjXduzY4eV58+Z5+a677iro9Q4ePOjlmpqagp6/uZiJBwAAAAJDEQ8AAAAEhiIeAAAACAw98ZJmz57t5enTpxf1euk+0Z/+9KcNr+mfB1As6T81vnHjRi+/8sorXs53LNy0aVNexwNAY7Zt29bwetmyZd62fHvgzzrrLC+fc845Xp44caKX+/fvn9f1csVMPAAAABAYingAAAAgMBTxAAAAQGDMOZdpe8aNoTKzop4/3SuVXjs5kya+Hs1R3A8HtE1tZqxL9r0PHDjQ27Zly5a8zl1ZWenlQYMGeTndF1rkdeIZ64DsBTPWff75517evn27l4cOHdrwOtux7dRTT/Xy7bff7uVhw4Z5uWfPnlmdv8AaHeuYiQcAAAACQxEPAAAABIYiHgAAAAhMm+2Jr6mp8XIh1/BM97zPmTPHy927d/dyNj34mzdv9nK/fv2yvDv6RIEctMuxLtPfrJC++ncr0uvMT5061cvp53/WrVuX8Xx5YqwDshfMWHfttdd6+YEHHsj5XJdeeqmXlyxZ4uWysrKcz10C9
MQDAAAAbQVFPAAAABAYingAAAAgMG2mJ3758uVeHjt2bMHOnV4becyYMRn3T/eN9ujRI+dr57BuPH2iQPaCGevS40t6LfikAQMGeHnVqlVebqpnPX2t+++/38vTp0/3cvp5oYULFzZ6bwXAWAdkr9WMdel14G+88UYvL1682MuZ6qETTjjBy6NHj/byzJkzvdypU6dm32crQE88AAAA0FZQxAMAAACBoYgHAAAAAhNMU1C65/3WW2/18pYtW4p27draWi9XVFR4Od1nmq1kH+nNN9+c17kAtG3vv/++l9NjX/IZnuHDh3vbVq5c6eXTTz/dy9muA1/iHngAbcj27du9nM868Ok6LP08UFvFTDwAAAAQGIp4AAAAIDAU8QAAAEBgWm1P/KRJk7z8zjvvePmFF17wcv/+/Qt27c2bN2fcnl67NFsnnXSSl+fMmdPwunv37nmdGwDqpceT9N+4qKmp8fLkyZO9nO4zpQceQK7Sf2finnvuyer4Xr16eXnWrFkNr9N1VXvBTDwAAAAQGIp4AAAAIDAU8QAAAEBgzDmXaXvGjcVUXV3t5cGDB3s5vW782LFjc75WeXl5xu35rgOftm7dOi+nP1uerJAnA9qJFhvrspXuY7/gggu8nFwfuaqqKuOxTT1LlOw5laRp06Y1+z5LgLEOyF7JxrodO3Z4uW/fvl7ev39/xuPHjRvn5fnz53u5a9euOd9b2rvvvuvlXbt2eXnKlCle3rp1q5dHjRrl5blz5za8Lisry/f2Gh3rmIkHAAAAAkMRDwAAAASGIh4AAAAITKtdJ75Pnz5eNite+2Ohe97T65Wm17Tv169fQa8HoP1Ijx/JHnhJ2rRpU8PriooKb1tTY12Rn9cB0IZ99tlnXp43b56Xm+qB79Gjh5dvueUWL2fTA79v376M154xY4aXH3/8cS+ne96bsmDBAi8n++DvvPNOb1vnzp2zOncmzMQDAAAAgaGIBwAAAALTYktM7t2718sbN2708syZM71c6JaXQqqsrPRy+k+blxjLrgHZC2aJybT0crxDhgxpdN821urHWAdkr2hj3dChQ7384osvZty/d+/eXl67dq2Xm1oCN5OrrrrKyw8//HDO58rXc8895+Vhw4ZlewqWmAQAAADaCop4AAAAIDAU8QAAAEBgWmyJyYEDB3p5y5YtLXQnX+0TzfZehg8fXsjbAYBmW7p0ac7HBtYDD6AVO+OMM7zcVE/8Mccc4+V8euDPP/98L69fvz7ncxXCdddd1/A6fW+FxEw8AAAAEBiKeAAAACAwFPEAAABAYErWE59eF74le+DT0mslz50718uLFi3KePzUqVO9vHDhwsLcGACkTJo0ycvp8Sn5jM+ECRO8bdOnT/dyTU2Nl+mRB9Bcb731lpe3bduWcf+ysjIv33bbbVldb8+ePV5Ojmevvfaat23fvn0Zz9Wpk1/+zp4928tnnnmmly+//HIvv/fee14eP358o+c77LDDMt5LPpiJBwAAAAJDEQ8AAAAEhiIeAAAACEzJeuLT68IXWnl5uZcfe+yxRq+f7senDxRAa1VdXe3lTD3wkvT66683vF65cmXxbgxAu1ZbW+vlFStWZNw//TzPqFGjMu6/a9cuL1944YVefuONN5q6xUYtWbLEy6eccoqX77vvPi9v377dy+PGjfPy4sWLc76XfDATDwAAAASGIh4AAAAIDEU8AAAAEJiS9cQXel34WbNmeXnatGkZ91+2bFnD65kzZ3rb0mvYN7UuPAAUS3o8GjJkSMb9kz3waWPHjs14LM8DASiVuro6L6fXcu/SpYuXn376aS/n0wOflq4Dd+zY4eX0OHz99dd7Ob2ufEthJh4AAAAIDEU8AAAAEBiKeAAAACAw5pzLtD3jxqwuZJbX8em1kGtqanI+V7rXqUePHjmfS5L27Nnj5e7du+d1vjzl9w8NtE8FG+vyle5TTz9PtG7dOi8PHjy40eOzPTYwjHVA9go21lVVVXl5xIgRW
R0/evRoL/fu3dvL6ecTv/jii6zOn40jjzzSy0888YSXzz77bC+XlZUV7V4OodGxjpl4AAAAIDAU8QAAAEBgKOIBAACAwJRsnfh8DRgwIK/jq6urG143te5yU9Jr1LdwDzyAgE2aNMnL6T72iRMnejndx57p+CaeeQKAnKXrssmTJ3t5wYIFGY9/9NFHC35PzXXTTTd5+YYbbvDycccdV8rbyRkz8QAAAEBgKOIBAACAwFDEAwAAAIEp2TrxFRUVXl61alVe50v3pdfW1np59erVXk73mWaSXpN+2bJlXm7layuzdjKQvZI1j6f/xkX//v0z7p8eo5v6Oxfl5eUNr9PrOLcxjHVA9oo21l1zzTVefvDBB4t1qSal15zfsGFDxu0dO3Ys+j3lgXXiAQAAgLaCIh4AAAAIDEU8AAAAEJiS9cRn2wdaSun++mnTprXQnRQEfaJA9krWE59e133RokVeTj+TM2HCBC9Pnz7dy8keeKnN98EnMdYB2SvaWLd+/Xov796928sjR4708v79+/O63vjx4xteX3zxxd62rl27evm8887L61otjJ54AAAAoK2giAcAAAACQxEPAAAABKZkPfFp6bWOp06d6uV0n2i2Jk6c6OUrrrii4XUrX+c9X/SJAtkrWU+8WX7/i6bHtoULF+Z1voAx1gHZK9lYh4KhJx4AAABoKyjiAQAAgMC0WDtNttJLVNbV1Xm5jbfIZINfMQPZC2aJycCXwC0kxjoge62mrkOz0U4DAAAAtBUU8QAAAEBgKOIBAACAwATTE49mo08UyB5jXXgY64DsMdaFh554AAAAoK2giAcAAAACQxEPAAAABIYiHgAAAAgMRTwAAAAQGIp4AAAAIDAU8QAAAEBgKOIBAACAwFDEAwAAAIGhiAcAAAACQxEPAAAABMaccy19DwAAAACywEw8AAAAEBiKeAAAACAwFPEAAABAYCjiAQAAgMBQxAMAAACBoYgHAAAAAvN/uEQ+0H/2m9oAAAAASUVORK5CYII=\n", 525 | "text/plain": [ 526 | "
" 527 | ] 528 | }, 529 | "metadata": { 530 | "needs_background": "light" 531 | }, 532 | "output_type": "display_data" 533 | }, 534 | { 535 | "data": { 536 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAvEAAACLCAYAAADoHsZvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAUCUlEQVR4nO3de5CU1ZnH8d/DAMM1oNHIxRGiKBETE6k1SAolKyWaiBpcY8CgktUU1BbKWqTWSEyMBlgtxRQmkY1r4WWNrIhhgxhYw6opchOUFaODRCRyiRlUAsoAM1zm7B/vy2w/r0zP9NBzOT3fT5VV769P99tnpOrMM2eePmMhBAEAAACIR6e2ngAAAACAwlDEAwAAAJGhiAcAAAAiQxEPAAAARIYiHgAAAIgMRTwAAAAQmQ5dxJvZ983ssbaeBwAUi5mdZGbVZlaW5znVZnZya84LANqamS03s2vbeh7FEl0Rb2YvmNlOMytv67kAQDGY2dtmti8trreb2UNm1qs59wohbAkh9AohHErv/YKZXZ95Tq8QwqZizB0A8knXt+1m1jPnsevN7IUWft+PbNSGEL4UQnikJd+3NUVVxJvZYEnnSgqSLm3TyWTk2/UCgCa4JITQS9JwSWdLurWN5wMAxdJZ0vS2nkSpiaqIl3SNpD9IelhS/a9DzOxhM/uJmT1jZrvN7EUzOyVn/Awz+5WZ/S39aXBmzj27mtmj6eteN7O/y3nd6eku1q507NLMe843s1+a2R5Jf9+SXziAjiGE8BdJyyV92swGmNnSdO3aaGbfPPw8M/u8mb1kZh+m69q96eODzSyYWWczm61k4+PH6S7/j9PnBDMbYmbnmFlV7iaEmY03s1fT605m9m0ze8vMdpjZIjM7tjX/fwAoCXdL+paZ9c0OmNmncmq0DWZ2Zc7Yx83s6XSdW2Nms8zsNznj88xsazr+spmdmz5+kaSZkr6Wrn3r0sdfSH8LUJ7Wdp/Oudfx6W9EP5HmcWb2Svq835nZmS32f6eZYizif5b+d6GZnZAzNlHS7ZKOkbRR0mxJMrPeklZKWiFpgKQhkv4n53WXSvpPSX0lLZV0+JtcF0lPS3pW0ick3SDpZ2Y2NOe1V6Xv01vSbwQAR8nMKiR9WdL/SlooaZuStesKSXPMbEz61HmS5oUQPibpFEmLsvcKIXxH0ipJ09IWmmmZ8T9I2iPp/JyHr5L0eHp9o6SvSBqdzmGnpJ8U4csE0LG8JOkFSd/KfTBtsfmVkjXnE0pqufvN7Iz0KT9Rskb1U7J5m+1nXyPpc5KOTe/xpJl1CyGskDRH0hPp2vfZ3BeFEGol/Tx9v8OulPTrEMK7ZjZc0gJJUyR9XNJPJS1tb63c0RTxZjZK0iBJi0IIL0t6S8k3m8N+HkJYHUI4qKTI/1z6+DhJVSGEuSGEmhDC7hDCizmv+00I4Zdp/+h/SDr8D32OpF6S7gwh7A8hPCdpmfw/+C9CCL8NIdSFEGqK/TUD6FD+y8x2KdkQ+LWkBySNknRzuna9IulBSVenzz8gaYiZHRdCqE4L8uZYqHRdSzc9vpw+JiXfwL4TQtiWftP7vqQrzKxzM98LQMf1PUk3mNnxOY+Nk/R2COGhEMLBEMJaSU8pWWfKJP2DpNtCCHtDCJWSXD97COGxEMKO9LVzJZVLyt1szedx+ZoudwPjm5J+GkJ4MYRwKO2jr1VSG7Yb0RTxSn76ejaE8H6aH5f/iawq53qvkgJckiqUFPwNyb6uW/oNaoCkrSGEupzxzZIG5uStTZ8+AOT1lRBC3xDCoBDCPylZg/4WQtid85zcN
eg6SadJeiP9NfO4Zr7v45IuT3eYLpe0NoSwOR0bJGlJ+uvkXZLWSzok6YQj3woAjiyE8JqSzdBv5zw8SNKIw2tMus58XcnO+/FKeulzay1Xd5nZDDNbb2YfpK/tI+m4Jk7pOUndzWyEmQ1Ssvm7JGdeMzLzqlCyLrcbUeymmFl3Jb/mKDOzw0V3uaS+ZvbZhl8pKfkHn9jIc47kHUkVZtYpp5A/SdKfcp4TmnFfAGiKdyQda2a9cwr5kyT9RZJCCG9KmmhmnZQU34vN7ONHuE/edSqEUGlmmyV9SX4nSkrWz38MIfz26L4UAJAk3SZpraS5ad6qpIXlguwT0534g5JO1P/XXhU54+dKulnSGEmvhxDqzGynJEuf0tjaV2dmi5TUiNslLctZa7dKmh1CmF34l9h6YtmJ/4qS3Z9hSn5S+pyk05X0el7TyGuXSepnZv+cfpCht5mNaMJ7vqikD+tfzKyLmX1R0iVK+ucBoEWFELZK+p2kfzWzbumHqq5T0i4oM5tkZsenmwy70pcdOsKttktq7Ez4x5X0v58n6cmcx/9N0ux0l+rwB78ua+7XBKBjCyFslPSEkvVGSmq008zs6rTW6mJmZ5vZ6Wmb888lfd/MepjZp+Rrvt5Kivz3JHU2s+9J+ljO+HZJg9ONjoY8LulrSnb/czcw/l3S1HSX3sysp5ldnLYcthuxFPHXSnooPf+46vB/Sj6E+nXl+Y1C+lPVBUoK8CpJb6oJJ8mEEPYr+dDrlyS9L+l+SdeEEN442i8GAJpooqTBSnbllyjpDf1VOnaRpNfNrFrJh1wnNPDZnHlK+kt3mtl9DbzPQklflPRcTsvi4dculfSsme1WcjpYUzZBAKAhd0jqKdXXaGMlTVCyzlVJuktJt4UkTVPSIlOl5HOLC5X0pkvSfys5yetPSloNa+TbbQ5vSOwws7VHmkj6Gck9Stpkluc8/pKSvvgfK/lA/0ZJk5v35bYcC4GOEAAAALRvZnaXpH4hhJL5q6tHI5adeAAAAHQg6RnyZ6YtLZ9X0lK4pLHXdRRRfLAVAAAAHU5vJS00AyS9q+QDsb9o0xm1I7TTAAAAAJGhnQYAAACIDEU8AAAAEJnGeuLptYmPNf4UABmsdfFhrQMKx1oXnwbXOnbiAQAAgMhQxAMAAACRoYgHAAAAIsM58ZLeffddl++55x6XzzrrLJcvv/xyl8vLywUAAAC0FnbiAQAAgMhQxAMAAACRoYgHAAAAImMh5D0ytEOcJ1pbW+vy4MGDXd6+fbvL69evd3no0KEtMq9m4uxkoHAdYq0rMax1QOFY6+LDOfEAAABAqaCIBwAAACJDEQ8AAABEhp74I9i7d6/LAwcOdLl///4uV1ZWtvicCkCfKFC4DrnWRY61Digca1186IkHAAAASgVFPAAAABAZingAAAAgMp3begLtUY8ePVzu2rWry1u2bGnN6QCIiJlvX+zUqbh7Jeeff77LPXv2dHnq1Kn112PGjHFjXbp0KepcAABth514AAAAIDIU8QAAAEBkOGKyCU444QSX9+zZ43J1dXVrTqcxHLsGFK5oa11ZWZnLxW6nKcSoUaNc7tevn8tz5851ObvWZb+Wdoa1Digcdd0RbNq0yeXly5e7vHLlSpcLWde/8Y1vuDxu3LgCZ8cRkwAAAEDJoIgHAAAAIkMRDwAAAESGnvgmoCceKHlFW+umT5/u8v3331+sW7e4Bx980OVrr722jWbSJKx1QOGo6yStW7fO5ZEjR7pcW1vrcrZWzh4lnD3qN9/niXbu3NnkeR5+u4YG2IkHAAAAIkMRDwAAAESGIh4AAACIDD3xR1BTU+PygAEDXL799ttdvuGGG1p8TgWgTxQoXNHWurq6Opd37drl8pw5c47q/vPnz3d5//79DT73oosucnnFihV5753t8zzmmGNcfumll1yuqKjIe78WxloHFK5k6rrcWm3WrFluLJuzbrnlFpeznwcaPXq0yzNmzHA5u1Z+8pOfd
LlLly4Nvvexxx6bd25HQE88AAAAUCoo4gEAAIDIUMQDAAAAkaEnXh89/zPbczpt2jSX//znP7s8aNCglplY89AnChQumrVu06ZNLmd78HP169fP5Wxf54IFCwp67+xa99xzz7l80kknFXS/o8RaBxQumrWuMRs2bKi/HjZsmBvbtm2by/3793d5ypQpLt91110u9+3btxhTLBZ64gEAAIBSQREPAAAARIYiHgAAAIgMPfH66LnwPXr0cLlPnz4uv/XWWy4348zPlkSfKFC4DrHWZc+Uf+qpp1y+6aabXN6xY0fe+2V75J9//nmXW/gcedY6oHDRrnXZPvevfvWr9dcHDx50Y2vWrGmVObUSeuIBAACAUkERDwAAAESGIh4AAACITOe2nkB7kO2Jz5o5c6bL7awHHgCapGvXri5PnDjR5ZEjR7o8Z84clx966CGXN2/e7PLZZ5/tclVVVbPmCQBZ2b/hs2/fvvrr3//+9609nXaBnXgAAAAgMhTxAAAAQGQo4gEAAIDIdMhz4vfs2ePymDFjXF67dq3LH3zwgcvdu3dvmYkVB2cnA4UrybXuaNXW1rqcPRc+e458p05+X+iBBx5wedKkSfXXZWVlRzs91jqgcNGudXPnznU5tw9+8eLFrT2d1sQ58QAAAECpoIgHAAAAIkMRDwAAAESmQ54T/8Ybb7i8evVqlxcuXOhyO++BB4AWUV5e7nK2J3Xy5Mku19XVuXz99de7PHbs2Prr/v37F2GGADqKbdu2uZzbE//222+7scGDB7fCjNoeO/EAAABAZCjiAQAAgMhQxAMAAACRifac+Jqamrzj3bp1czn3vONLLrnEja1Zs8bl999/3+UinGfcmjg7GShcu13r2pMDBw64/Oqrr7p8zjnn5H39FVdcUX+d/exRM7DWAYWLdq175ZVXXL766qvrr4cMGeLGlixZ0ipzaiWcEw8AAACUCop4AAAAIDLt9ojJgwcPupz9893z5s3L+/oTTzzR5aqqqvrr9evXu7GbbrrJ5cjaZwCgVXTp0sXlYcOGuTx+/HiXs7/Sfv755+uvX3/9dTd2xhlnFGOKAEpUdo2YMGFC/fWtt97qxn70ox+5PGXKFJe7du1a5Nm1DXbiAQAAgMhQxAMAAACRoYgHAAAAItNueuJXrVrl8pVXXulyZWWly1u2bHE5++d4ly1b5vKHH35Yf33qqae6sdmzZxc2WQBoJ7LH7T722GPNvtekSZNc7tzZf4t48803XR4wYIDL27dvz3v/HTt21F9n12x64gHkk/1MzowZM+qvs8fdTp8+3eXq6mqXb7nlliLPrm2wEw8AAABEhiIeAAAAiAxFPAAAABCZNuuJP3TokMtz5sxxObeHXZLM/F+dvfPOO11es2aNyytXrmzwftneqN27d7vcrVu3hqYNAHm99tprLv/1r391edq0aUV9vwMHDri8devW+uv+/fu7sbq6OpezPey557hLUo8ePVzOnvves2dPl9955528c+3bt2+DcwOAQuTWatnPAu3fv9/lH/zgBy5ffPHFLp955plFnl3rYCceAAAAiAxFPAAAABAZingAAAAgMhZCyDeed/BoZHvec3slJX/+pyTdfffdLmd7OydMmODy4sWLXS4rK6u/zvbjn3766S7/8Y9/dLlTp6h+1rHGnwIgo2hrXe5aI7Xt+lFRUeFydr3PntXe0u699976682bN7uxe+65p9DbsdYBhWuxuq49Wbt2rctjx451+ZlnnnF5xIgRLT6no9DgWhdVdQoAAACAIh4AAACIDkU8AAAAEJk2Oye+Me+9957LNTU1LmfP+Myeb3zddde5nNtvOWvWLDc2d+5cl0855RSXs2fQH3fccQ1NGwDajcmTJ7uc/axQazv33HPrry+99NI2nAmAo7Vv3z6XN27c6PJnPvOZ1pyOM3z4cJdHjx7t8sMPP+xyO++JbxA78QAAAEBkKOIBAACAyFDEAwAAAJFps574rl27ujxy5EiXH
330UZefffZZl6uqqly+5pprXM6eK9+nT5/669tuu82N7d271+X58+e7PHTo0AbvdSTZXqzseaU//OEPXb7sssvy3g8AGpLtO839Oxf9+vVzY9XV1S5ne+azfaKFeuSRR1y+8MILXc5dOzt3brcfyQLQBBMnTnT5C1/4gssnn3yyyz179mzxOR2WrfOWLFnicmVlZavNpSWxEw8AAABEhiIeAAAAiAxFPAAAABAZCyHkG887WEx79uxxedSoUS6vW7fO5ZkzZ7p88803u9y7d+8mv3ddXZ3Lq1atcnnZsmVNvpf00f78bt26uXzqqacWdL8CWUveHChRRVvrevXqlXe8trY27/jAgQNdzq6FTzzxhMuTJk1q8F5btmxxeerUqS5fddVVeefSmPLycpdXrFjhcnbuRcZaBxSuaGvdwoULXc6uRWeddZbL9913n8sVFRV5c2M2bdpUf718+XI3lv1cZO7fCpKk8ePHu1xWVlbQe7eyBtc6duIBAACAyFDEAwAAAJGhiAcAAAAi02564lE09IkChWuxtW737t0uz5s3z+X169e7vGjRoqK994gRI1zOrverV68u6H7nnXeey0uXLnW5Nc+BFmsd0BxFW+sa+zzhmDFj/Btn1p/u3bu7PGTIEJez69e+fftcfvrpp+uvL7jgAjd2xx13uJz9LGI774HPoiceAAAAKBUU8QAAAEBkKOIBAACAyNATX3roEwUK12ZrXU1Njcs7d+7M+/wbb7zR5erqapc3bNhQf71161Y3dtppp+V97+x49u9tLFiwwOXGzsRvYax1QOFaba17+eWXXX7yySdd3rhxY0H3y/bQf/e7362/zq5dJYaeeAAAAKBUUMQDAAAAkaGIBwAAACJDT3zpoU8UKFzJrHW5ffCVlZVubPjw4S5ne+IrKipabmLFx1oHFK5k1roOhJ54AAAAoFRQxAMAAACRoZ2m9PArZqBwrHXxYa0DCsdaFx/aaQAAAIBSQREPAAAARIYiHgAAAIgMRTwAAAAQGYp4AAAAIDIU8QAAAEBkKOIBAACAyFDEAwAAAJGhiAcAAAAiQxEPAAAARIYiHgAAAIgMRTwAAAAQGYp4AAAAIDIU8QAAAEBkKOIBAACAyFgIoa3nAAAAAKAA7MQDAAAAkaGIBwAAACJDEQ8AAABEhiIeAAAAiAxFPAAAABAZingAAAAgMv8He786xB9OhE8AAAAASUVORK5CYII=\n", 537 | "text/plain": [ 538 | "
" 539 | ] 540 | }, 541 | "metadata": { 542 | "needs_background": "light" 543 | }, 544 | "output_type": "display_data" 545 | } 546 | ], 547 | "source": [ 548 | "triplets = get_batch_random(2)\n", 549 | "print(\"Checking batch width, should be 3 : \",len(triplets))\n", 550 | "print(\"Shapes in the batch A:{0} P:{1} N:{2}\".format(triplets[0].shape, triplets[1].shape, triplets[2].shape))\n", 551 | "drawTriplets(triplets)\n", 552 | "hardtriplets = get_batch_hard(50,1,1,network)\n", 553 | "print(\"Shapes in the hardbatch A:{0} P:{1} N:{2}\".format(hardtriplets[0].shape, hardtriplets[1].shape, hardtriplets[2].shape))\n", 554 | "drawTriplets(hardtriplets)" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": 14, 560 | "metadata": {}, 561 | "outputs": [], 562 | "source": [ 563 | "# Hyper parameters\n", 564 | "evaluate_every = 1000 # interval for evaluating on one-shot tasks\n", 565 | "batch_size = 32\n", 566 | "n_iter = 80000 # No. of training iterations\n", 567 | "n_val = 250 # how many one-shot tasks to validate on" 568 | ] 569 | }, 570 | { 571 | "cell_type": "code", 572 | "execution_count": 15, 573 | "metadata": {}, 574 | "outputs": [], 575 | "source": [ 576 | "def compute_probs(network,X,Y):\n", 577 | " '''\n", 578 | " Input\n", 579 | " network : current NN to compute embeddings\n", 580 | " X : tensor of shape (m,w,h,1) containing pics to evaluate\n", 581 | " Y : tensor of shape (m,) containing true class\n", 582 | " \n", 583 | " Returns\n", 584 | " probs : array of shape (m*(m-1)/2,) containing the negated embedding distance of each unordered pair of pics\n", 585 | " y : array of shape (m*(m-1)/2,) containing 1 when the two pics of the pair share the same class, 0 otherwise\n", 586 | " '''\n", 587 | " m = X.shape[0]\n", 588 | " nbevaluation = int(m*(m-1)/2)\n", 589 | " probs = np.zeros((nbevaluation))\n", 590 | " y = np.zeros((nbevaluation))\n", 591 | " \n", 592 | " #Compute all embeddings for all pics with current network\n", 593 | " embeddings = network.predict(X)\n", 594 | " \n", 595 | " size_embedding = embeddings.shape[1]\n", 596 | " \n", 597 | " #For each pics of our dataset\n", 598 | " k = 0\n", 599 | 
" for i in range(m):\n", 600 | " #Against all other images\n", 601 | " for j in range(i+1,m):\n", 602 | " #compute the probability of being the right decision : it should be 1 for right class, 0 for all other classes\n", 603 | " probs[k] = -compute_dist(embeddings[i,:],embeddings[j,:])\n", 604 | " if (Y[i]==Y[j]):\n", 605 | " y[k] = 1\n", 606 | " #print(\"{3}:{0} vs {1} : {2}\\tSAME\".format(i,j,probs[k],k))\n", 607 | " else:\n", 608 | " y[k] = 0\n", 609 | " #print(\"{3}:{0} vs {1} : \\t\\t\\t{2}\\tDIFF\".format(i,j,probs[k],k))\n", 610 | " k += 1\n", 611 | " return probs,y\n", 612 | "#probs,yprobs = compute_probs(network,x_test_origin[:10,:,:,:],y_test_origin[:10])\n", 613 | "\n", 614 | "def compute_metrics(probs,yprobs):\n", 615 | " '''\n", 616 | " Returns\n", 617 | " fpr : Increasing false positive rates such that element i is the false positive rate of predictions with score >= thresholds[i]\n", 618 | " tpr : Increasing true positive rates such that element i is the true positive rate of predictions with score >= thresholds[i].\n", 619 | " thresholds : Decreasing thresholds on the decision function used to compute fpr and tpr. thresholds[0] represents no instances being predicted and is arbitrarily set to max(y_score) + 1\n", 620 | " auc : Area Under the ROC Curve metric\n", 621 | " '''\n", 622 | " # calculate AUC\n", 623 | " auc = roc_auc_score(yprobs, probs)\n", 624 | " # calculate roc curve\n", 625 | " fpr, tpr, thresholds = roc_curve(yprobs, probs)\n", 626 | " \n", 627 | " return fpr, tpr, thresholds,auc\n", 628 | "\n", 629 | "def compute_interdist(network):\n", 630 | " '''\n", 631 | " Computes sum of distances between all classes embeddings on our reference test image: \n", 632 | " d(0,1) + d(0,2) + ... + d(0,9) + d(1,2) + d(1,3) + ... 
d(8,9)\n", 633 | " A good model should have a large distance between all these embeddings\n", 634 | " \n", 635 | " Returns:\n", 636 | " array of shape (nb_classes,nb_classes) whose entry (i,j) is the distance between the reference embeddings of classes i and j\n", 637 | " '''\n", 638 | " res = np.zeros((nb_classes,nb_classes))\n", 639 | " \n", 640 | " ref_images = np.zeros((nb_classes,img_rows,img_cols,1))\n", 641 | " \n", 642 | " #generates embeddings for reference images\n", 643 | " for i in range(nb_classes):\n", 644 | " ref_images[i,:,:,:] = dataset_test[i][0,:,:,:]\n", 645 | " ref_embeddings = network.predict(ref_images)\n", 646 | " \n", 647 | " for i in range(nb_classes):\n", 648 | " for j in range(nb_classes):\n", 649 | " res[i,j] = compute_dist(ref_embeddings[i],ref_embeddings[j])\n", 650 | " return res\n", 651 | "\n", 652 | "def draw_interdist(network,n_iteration):\n", 653 | " interdist = compute_interdist(network)\n", 654 | " \n", 655 | " data = []\n", 656 | " for i in range(nb_classes):\n", 657 | " data.append(np.delete(interdist[i,:],[i]))\n", 658 | "\n", 659 | " fig, ax = plt.subplots()\n", 660 | " ax.set_title('Evaluating embeddings distance from each other after {0} iterations'.format(n_iteration))\n", 661 | " ax.set_ylim([0,3])\n", 662 | " plt.xlabel('Classes')\n", 663 | " plt.ylabel('Distance')\n", 664 | " ax.boxplot(data,showfliers=False,showbox=True)\n", 665 | " locs, labels = plt.xticks()\n", 666 | " plt.xticks(locs,np.arange(nb_classes))\n", 667 | "\n", 668 | " plt.show()\n", 669 | " \n", 670 | "def find_nearest(array,value):\n", 671 | " idx = np.searchsorted(array, value, side=\"left\")\n", 672 | " if idx > 0 and (idx == len(array) or math.fabs(value - array[idx-1]) < math.fabs(value - array[idx])):\n", 673 | " return array[idx-1],idx-1\n", 674 | " else:\n", 675 | " return array[idx],idx\n", 676 | " \n", 677 | "def draw_roc(fpr, tpr,thresholds):\n", 678 | " #find the threshold whose false positive rate is nearest to targetfpr\n", 679 | " targetfpr=1e-3\n", 680 | " _, idx = find_nearest(fpr,targetfpr)\n", 681 | " threshold = thresholds[idx]\n", 682 | " recall = tpr[idx]\n", 683 | 
" \n", 684 | " \n", 685 | " # plot no skill\n", 686 | " plt.plot([0, 1], [0, 1], linestyle='--')\n", 687 | " # plot the roc curve for the model\n", 688 | " plt.plot(fpr, tpr, marker='.')\n", 689 | " plt.title('AUC: {0:.3f}\\nSensitivity : {2:.1%} @FPR={1:.0e}\\nThreshold={3})'.format(auc,targetfpr,recall,abs(threshold) ))\n", 690 | " # show the plot\n", 691 | " plt.show()" 692 | ] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "execution_count": 16, 697 | "metadata": {}, 698 | "outputs": [ 699 | { 700 | "data": { 701 | "image/png": "\n", 702 | "text/plain": [ 703 | "
" 704 | ] 705 | }, 706 | "metadata": { 707 | "needs_background": "light" 708 | }, 709 | "output_type": "display_data" 710 | }, 711 | { 712 | "ename": "NameError", 713 | "evalue": "name 'dist' is not defined", 714 | "output_type": "error", 715 | "traceback": [ 716 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 717 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 718 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mfpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mthresholds\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mauc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprobs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0myprob\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mdraw_roc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtpr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mthresholds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mdraw_interdist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnetwork\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mn_iteration\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 719 | "\u001b[0;32m\u001b[0m in \u001b[0;36mdraw_interdist\u001b[0;34m(network, n_iteration)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdraw_interdist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnetwork\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mn_iteration\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m \u001b[0minterdist\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mcompute_interdist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnetwork\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 79\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 720 | "\u001b[0;32m\u001b[0m in \u001b[0;36mcompute_interdist\u001b[0;34m(network)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnb_classes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnb_classes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0mres\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mref_embeddings\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mref_embeddings\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 721 | "\u001b[0;31mNameError\u001b[0m: name 'dist' is not defined" 722 | ] 723 | } 724 | ], 725 | "source": [ 726 | "#Testing on an untrained network\n", 727 | "probs,yprob = compute_probs(network,x_test_origin[:500,:,:,:],y_test_origin[:500])\n", 728 | "fpr, tpr, 
thresholds,auc = compute_metrics(probs,yprob)\n", 729 | "draw_roc(fpr, tpr,thresholds)\n", 730 | "draw_interdist(network,n_iteration)" 731 | ] 732 | }, 733 | { 734 | "cell_type": "code", 735 | "execution_count": null, 736 | "metadata": {}, 737 | "outputs": [], 738 | "source": [] 739 | } 740 | ], 741 | "metadata": { 742 | "kernelspec": { 743 | "display_name": "Python 3", 744 | "language": "python", 745 | "name": "python3" 746 | }, 747 | "language_info": { 748 | "codemirror_mode": { 749 | "name": "ipython", 750 | "version": 3 751 | }, 752 | "file_extension": ".py", 753 | "mimetype": "text/x-python", 754 | "name": "python", 755 | "nbconvert_exporter": "python", 756 | "pygments_lexer": "ipython3", 757 | "version": "3.7.6" 758 | } 759 | }, 760 | "nbformat": 4, 761 | "nbformat_minor": 4 762 | } 763 | --------------------------------------------------------------------------------