├── LICENSE.txt ├── README.md ├── build └── lib │ └── categorical_embedder │ └── __init__.py ├── categorical_embedder.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── requires.txt └── top_level.txt ├── categorical_embedder └── __init__.py ├── demo └── demo.gif ├── dist ├── categorical_embedder-0.1-py3-none-any.whl └── categorical_embedder-0.1.tar.gz ├── example_notebook ├── Example Notebook.ipynb └── HR_Attrition_Data.csv └── setup.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 The Python Packaging Authority 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Categorical Embedder 2 | 3 | **Categorical Embedder** is a Python package that lets you convert your categorical variables into numeric embeddings via neural networks. 4 | 5 | ![Categorical Embedder Demo](demo/demo.gif) 6 | 7 | ## Installation 8 | 9 | `pip install categorical_embedder` 10 | 11 | ## Example 12 | ```py 13 | import categorical_embedder as ce 14 | from sklearn.model_selection import train_test_split 15 | 16 | df = pd.read_csv('HR_Attrition_Data.csv') 17 | X = df.drop(['employee_id', 'is_promoted'], axis=1) 18 | y = df['is_promoted'] 19 | 20 | embedding_info = ce.get_embedding_info(X) 21 | X_encoded,encoders = ce.get_label_encoded_data(X) 22 | 23 | X_train, X_test, y_train, y_test = train_test_split(X_encoded,y) 24 | 25 | embeddings = ce.get_embeddings(X_train, y_train, categorical_embedding_info=embedding_info, 26 | is_classification=True, epochs=100,batch_size=256) 27 | ``` 28 | A more detailed [Jupyter Notebook](https://github.com/Shivanandroy/CategoricalEmbedder/blob/master/example_notebook/Example%20Notebook.ipynb) is also available. 29 | 30 | > What's inside **Categorical Embedder**? 31 | * `ce.get_embedding_info(data,categorical_variables=None)`: This function identifies all categorical variables in the data and determines the embedding size of each. The embedding size of a categorical variable is the minimum of 50 and half the number of its unique values (rounded up), i.e. embedding size of a column = Min(50, (# unique values in that column + 1) // 2). 32 | One can pass an explicit list of categorical variables in the `categorical_variables` parameter. 
If `None`, this function automatically takes all the variables with data type `object`. 33 | * `ce.get_label_encoded_data(data, categorical_variables=None)`: This function label encodes (integer encoding) all the categorical variables using sklearn.preprocessing.LabelEncoder and returns a label-encoded dataframe for training, together with the fitted encoders. Keras/TensorFlow or any other deep learning library would expect the data to be in this format. 34 | * `ce.get_embeddings(X_train, y_train, categorical_embedding_info=embedding_info, is_classification=True, epochs=100,batch_size=256)`: This function trains a shallow neural network and returns the embeddings of the categorical variables. Under the hood, it is a 2-layer neural network architecture with 1000 and 500 neurons with 'ReLU' activation. It takes 4 required inputs - `X_train`, `y_train`, `categorical_embedding_info` (the output of the `get_embedding_info` function) and `is_classification`: `True` for classification tasks; `False` for regression tasks. 35 | 36 | For classification: `loss = 'binary_crossentropy'; metrics = 'accuracy'` and for regression: `loss = 'mean_squared_error'; metrics = 'r2'` 37 | 38 | ## Dependencies 39 | ``` 40 | numpy 41 | pandas 42 | scikit-learn 43 | tensorflow 44 | keras 45 | tqdm 46 | keras-tqdm 47 | ``` 48 | ## Contributors 49 | * [Prakash Behera](https://github.com/Praks07) 50 | 51 | -------------------------------------------------------------------------------- /build/lib/categorical_embedder/__init__.py: -------------------------------------------------------------------------------- 1 | # Necessary imports: 2 | from keras.models import Sequential, Model, model_from_json 3 | from keras.layers import Dense, Dropout, Embedding, Activation, Input, concatenate, Reshape, Flatten 4 | from keras.layers.normalization import BatchNormalization 5 | from keras.layers.advanced_activations import PReLU 6 | from keras.optimizers import Adam 7 | from keras_tqdm import TQDMNotebookCallback 8 | from keras import backend as K 9 | from
sklearn.preprocessing import LabelEncoder 10 | from sklearn.utils.validation import check_is_fitted 11 | from sklearn.utils import column_or_1d 12 | from sklearn.model_selection import train_test_split 13 | import numpy as np 14 | import pandas as pd 15 | import warnings 16 | from tqdm import tqdm_notebook 17 | warnings.filterwarnings("ignore") 18 | 19 | # Helper functions: 20 | 21 | class __LabelEncoder__(LabelEncoder): 22 | 23 | def transform(self, y): 24 | 25 | check_is_fitted(self, 'classes_') 26 | y = column_or_1d(y, warn=True) 27 | 28 | unseen = len(self.classes_) 29 | 30 | e = np.array([ 31 | np.searchsorted(self.classes_, x) 32 | if x in self.classes_ else unseen 33 | for x in y 34 | ]) 35 | 36 | if unseen in e: 37 | self.classes_ = np.array(self.classes_.tolist() + ['unseen']) 38 | 39 | return e 40 | 41 | 42 | def get_embedding_info(data, categorical_variables=None): 43 | ''' 44 | this function identifies categorical variables and its embedding size 45 | 46 | :data: input data [dataframe] 47 | :categorical_variables: list of categorical_variables [default: None] 48 | if None, it automatically takes the variables with data type 'object' 49 | 50 | embedding size of categorical variables are determined by minimum of 50 or half of the no. of its unique values. 51 | i.e. 
embedding size of a column = Min(50, # unique values of that column) 52 | ''' 53 | if categorical_variables is None: 54 | categorical_variables = data.select_dtypes(include='object').columns 55 | 56 | return {col:(data[col].nunique(),min(50,(data[col].nunique()+ 1) //2)) for col in categorical_variables} 57 | 58 | 59 | def get_label_encoded_data(data, categorical_variables=None): 60 | ''' 61 | this function label encodes all the categorical variables using sklearn.preprocessing.labelencoder 62 | and returns a label encoded dataframe for training 63 | 64 | :data: input data [dataframe] 65 | :categorical_variables: list of categorical_variables [Default: None] 66 | if None, it automatically takes the variables with data type 'object' 67 | ''' 68 | encoders = {} 69 | 70 | df = data.copy() 71 | 72 | if categorical_variables is None: 73 | categorical_variables = [col for col in df.columns if df[col].dtype == 'object'] 74 | 75 | for var in categorical_variables: 76 | #print(var) 77 | encoders[var] = __LabelEncoder__() 78 | df.loc[:, var] = encoders[var].fit_transform(df[var]) 79 | 80 | return df, encoders 81 | 82 | 83 | 84 | def r2(y_true, y_pred): 85 | SS_res = K.sum(K.square(y_true - y_pred)) 86 | SS_tot = K.sum(K.square(y_true - K.mean(y_true))) 87 | return (1 - SS_res / (SS_tot + K.epsilon())) 88 | 89 | def precision(y_true, y_pred): 90 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) 91 | predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1))) 92 | precision = true_positives / (predicted_positives + K.epsilon()) 93 | return precision 94 | 95 | 96 | def recall(y_true, y_pred): 97 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) 98 | possible_positives = K.sum(K.round(K.clip(y_true, 0, 1))) 99 | recall = true_positives / (possible_positives + K.epsilon()) 100 | return recall 101 | 102 | 103 | # Main function: 104 | 105 | def get_embeddings(X_train, y_train, categorical_embedding_info, is_classification, epochs=100, batch_size=256): 
106 | ''' 107 | this function trains a shallow neural networks and returns embeddings of categorical variables 108 | 109 | :X_train: training data [dataframe] 110 | :y_train: target variable 111 | :categorical_embedding_info: output of get_embedding_info function [dictionary of categorical variable and it's embedding size] 112 | :is_classification: True for classification tasks; False for regression tasks 113 | :epochs: num of epochs to train [default:100] 114 | :batch_size: batch size to train [default:256] 115 | 116 | It is a 2 layer neural network architecture with 1000 and 500 neurons with 'ReLU' activation 117 | for classification: loss = 'binary_crossentropy'; metrics = 'accuracy' 118 | for regression: loss = 'mean_squared_error'; metrics = 'r2' 119 | 120 | ''' 121 | 122 | numerical_variables = [x for x in X_train.columns if x not in list(categorical_embedding_info.keys())] 123 | 124 | inputs = [] 125 | flatten_layers = [] 126 | 127 | for var, sz in categorical_embedding_info.items(): 128 | input_c = Input(shape=(1,), dtype='int32') 129 | embed_c = Embedding(*sz, input_length=1)(input_c) 130 | flatten_c = Flatten()(embed_c) 131 | inputs.append(input_c) 132 | flatten_layers.append(flatten_c) 133 | #print(inputs) 134 | 135 | input_num = Input(shape=(len(numerical_variables),), dtype='float32') 136 | flatten_layers.append(input_num) 137 | inputs.append(input_num) 138 | 139 | flatten = concatenate(flatten_layers, axis=-1) 140 | 141 | fc1 = Dense(1000, kernel_initializer='normal')(flatten) 142 | fc1 = Activation('relu')(fc1) 143 | 144 | 145 | 146 | fc2 = Dense(500, kernel_initializer='normal')(fc1) 147 | fc2 = Activation('relu')(fc2) 148 | 149 | 150 | if is_classification: 151 | output = Dense(1, activation='sigmoid')(fc2) 152 | 153 | else: 154 | output = Dense(1, kernel_initializer='normal')(fc2) 155 | 156 | 157 | nnet = Model(inputs=inputs, outputs=output) 158 | 159 | x_inputs = [] 160 | for col in categorical_embedding_info.keys(): 161 | 
x_inputs.append(X_train[col].values) 162 | 163 | x_inputs.append(X_train[numerical_variables].values) 164 | 165 | if is_classification: 166 | loss = 'binary_crossentropy' 167 | metrics='accuracy' 168 | else: 169 | loss = 'mean_squared_error' 170 | metrics=r2 171 | 172 | 173 | 174 | nnet.compile(loss=loss, optimizer='adam', metrics=[metrics]) 175 | nnet.fit(x_inputs, y_train.values, batch_size=batch_size, epochs=epochs, validation_split=0.2, callbacks=[TQDMNotebookCallback()], verbose=0) 176 | 177 | embs = list(map(lambda x: x.get_weights()[0], [x for x in nnet.layers if 'Embedding' in str(x)])) 178 | embeddings = {var: emb for var, emb in zip(categorical_embedding_info.keys(), embs)} 179 | return embeddings 180 | 181 | def get_embeddings_in_dataframe(embeddings, encoders): 182 | ''' 183 | this function return the embeddings in pandas dataframe 184 | 185 | :embeddings: output of 'get_embeddings' function 186 | :encoders: output of 'get_embedding_info' function 187 | 188 | ''' 189 | 190 | assert len(embeddings)==len(encoders), "Categorical variables in embeddings does not match with those of encoders" 191 | 192 | dfs={} 193 | for cat_var in tqdm_notebook(embeddings.keys()): 194 | df = pd.DataFrame(embeddings[cat_var]) 195 | df.index = encoders[cat_var].classes_ 196 | df.columns = [cat_var + '_embedding_' + str(num) for num in df.columns] 197 | dfs[cat_var] = df 198 | 199 | return dfs 200 | 201 | 202 | def fit_transform(data, embeddings, encoders, drop_categorical_vars=False): 203 | ''' 204 | this function includes the trained embeddings into your data 205 | 206 | :data: input data [dataframe] 207 | :embeddings: output of 'get_embeddings' function 208 | :encoders: output of 'get_embedding_info' function 209 | :drop_categorical_vars: False to keep the categorical variables in the data along with the embeddings 210 | if True - drops the categorical variables and replaces them with trained embeddings 211 | 212 | ''' 213 | 214 | assert len(embeddings)==len(encoders), 
"Categorical variables in embeddings does not match with those of encoders" 215 | 216 | dfs={} 217 | for cat_var in tqdm_notebook(embeddings.keys()): 218 | df = pd.DataFrame(embeddings[cat_var]) 219 | df.index = encoders[cat_var].classes_ 220 | df.columns = [cat_var + '_embedding_' + str(num) for num in df.columns] 221 | data = data.merge(df, how='left', left_on=cat_var, right_index=True) 222 | 223 | if drop_categorical_vars: 224 | return data.drop(list(embeddings.keys()), axis=1) 225 | else: 226 | return data 227 | -------------------------------------------------------------------------------- /categorical_embedder.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: categorical-embedder 3 | Version: 0.1 4 | Summary: Categorical Embedder is a python package that let's you convert your categorical variables into numeric via Neural Networks 5 | Home-page: https://github.com/Shivanandroy/CategoricalEmbedder/ 6 | Author: Shivanand Roy 7 | Author-email: Shivanandroy.official@gmail.com 8 | License: UNKNOWN 9 | Description: # Categorical Embedder 10 | 11 | **Categorical Embedder** is a python package that let's you convert your categorical variables into numeric via Neural Networks 12 | 13 | ## Installation 14 | 15 | `pip install categorical_embedder` 16 | 17 | ## Example 18 | ```py 19 | import categorical_embedder as ce 20 | from sklearn.model_selection import train_test_split 21 | 22 | df = pd.read_csv('HR_Attrition_Data.csv') 23 | X = df.drop(['employee_id', 'is_promoted'], axis=1) 24 | y = df['is_promoted'] 25 | 26 | embedding_info = ce.get_embedding_info(X) 27 | X_encoded,encoders = ce.get_label_encoded_data(X) 28 | 29 | X_train, X_test, y_train, y_test = train_test_split(X_encoded,y) 30 | 31 | embeddings = ce.get_embeddings(X_train, y_train, categorical_embedding_info=embedding_info, 32 | is_classification=True, epochs=100,batch_size=256) 33 | ``` 34 | A more detailed [Jupyter 
Notebook](http://www.github.com ) can be found here 35 | 36 | > What's inside **Categorical Embedder** ? 37 | * `ce.get_embedding_info(data,categorical_variables=None)`: This function identifies all categorical variables in the data, determines its embedding size. Embedding size of the categorical variables are determined by minimum of 50 or half of the no. of its unique values i.e. embedding size of a column = Min(50, # unique values in that column) 38 | One can pass explicit list of categorical variables in `categorical_variables` parameter. If `None`, this function automatically takes all the variables with data type `object` 39 | * `ce.get_label_encoded_data(data, categorical_variables=None)`: This function label encodes (integer encoding) all the categorical variables using sklearn.preprocessing.LabelEncoder and returns a label encoded dataframe for training. Keras/tensorflow or any other deep learning library would expect the data to be in this format. 40 | * `ce.get_embeddings(X_train, y_train, categorical_embedding_info=embedding_info, is_classification=True, epochs=100,batch_size=256)`: This function trains a shallow neural networks and returns embeddings of categorical variables. Under the hood, It is a 2 layer neural network architecture with 1000 and 500 neurons with 'ReLU' activation. It takes 4 required inputs - `X_train`, `y_train`, `categorical_embedding_info`:output of get_embedding_info function and `is_classification`: `True` for classification tasks; `False` for regression tasks. 
41 | 42 | For classification: `loss = 'binary_crossentropy'; metrics = 'accuracy'` and for regression: `loss = 'mean_squared_error'; metrics = 'r2'` 43 | 44 | ## Dependencies 45 | ```numpy 46 | pandas 47 | scikit-learn 48 | tensorflow 49 | keras 50 | tqdm 51 | keras-tqdm 52 | ``` 53 | 54 | 55 | 56 | Platform: UNKNOWN 57 | Classifier: Programming Language :: Python :: 3 58 | Classifier: License :: OSI Approved :: MIT License 59 | Classifier: Operating System :: OS Independent 60 | Description-Content-Type: text/markdown 61 | -------------------------------------------------------------------------------- /categorical_embedder.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | setup.py 3 | categorical_embedder/__init__.py 4 | categorical_embedder.egg-info/PKG-INFO 5 | categorical_embedder.egg-info/SOURCES.txt 6 | categorical_embedder.egg-info/dependency_links.txt 7 | categorical_embedder.egg-info/requires.txt 8 | categorical_embedder.egg-info/top_level.txt -------------------------------------------------------------------------------- /categorical_embedder.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /categorical_embedder.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | tensorflow 4 | keras 5 | tqdm 6 | keras-tqdm 7 | sklearn 8 | -------------------------------------------------------------------------------- /categorical_embedder.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | categorical_embedder 2 | -------------------------------------------------------------------------------- /categorical_embedder/__init__.py: -------------------------------------------------------------------------------- 1 | # 
Necessary imports: 2 | from keras.models import Sequential, Model, model_from_json 3 | from keras.layers import Dense, Dropout, Embedding, Activation, Input, concatenate, Reshape, Flatten 4 | from keras.layers.normalization import BatchNormalization 5 | from keras.layers.advanced_activations import PReLU 6 | from keras.optimizers import Adam 7 | from keras_tqdm import TQDMNotebookCallback 8 | from keras import backend as K 9 | from sklearn.preprocessing import LabelEncoder 10 | from sklearn.utils.validation import check_is_fitted 11 | from sklearn.utils import column_or_1d 12 | from sklearn.model_selection import train_test_split 13 | import numpy as np 14 | import pandas as pd 15 | import warnings 16 | from tqdm import tqdm_notebook 17 | warnings.filterwarnings("ignore") 18 | 19 | # Helper functions: 20 | 21 | class __LabelEncoder__(LabelEncoder): 22 | 23 | def transform(self, y): 24 | 25 | check_is_fitted(self, 'classes_') 26 | y = column_or_1d(y, warn=True) 27 | 28 | unseen = len(self.classes_) 29 | 30 | e = np.array([ 31 | np.searchsorted(self.classes_, x) 32 | if x in self.classes_ else unseen 33 | for x in y 34 | ]) 35 | 36 | if unseen in e: 37 | self.classes_ = np.array(self.classes_.tolist() + ['unseen']) 38 | 39 | return e 40 | 41 | 42 | def get_embedding_info(data, categorical_variables=None): 43 | ''' 44 | this function identifies categorical variables and its embedding size 45 | 46 | :data: input data [dataframe] 47 | :categorical_variables: list of categorical_variables [default: None] 48 | if None, it automatically takes the variables with data type 'object' 49 | 50 | embedding size of categorical variables are determined by minimum of 50 or half of the no. of its unique values. 51 | i.e. 
embedding size of a column = Min(50, # unique values of that column) 52 | ''' 53 | if categorical_variables is None: 54 | categorical_variables = data.select_dtypes(include='object').columns 55 | 56 | return {col:(data[col].nunique(),min(50,(data[col].nunique()+ 1) //2)) for col in categorical_variables} 57 | 58 | 59 | def get_label_encoded_data(data, categorical_variables=None): 60 | ''' 61 | this function label encodes all the categorical variables using sklearn.preprocessing.labelencoder 62 | and returns a label encoded dataframe for training 63 | 64 | :data: input data [dataframe] 65 | :categorical_variables: list of categorical_variables [Default: None] 66 | if None, it automatically takes the variables with data type 'object' 67 | ''' 68 | encoders = {} 69 | 70 | df = data.copy() 71 | 72 | if categorical_variables is None: 73 | categorical_variables = [col for col in df.columns if df[col].dtype == 'object'] 74 | 75 | for var in categorical_variables: 76 | #print(var) 77 | encoders[var] = __LabelEncoder__() 78 | df.loc[:, var] = encoders[var].fit_transform(df[var]) 79 | 80 | return df, encoders 81 | 82 | 83 | 84 | def r2(y_true, y_pred): 85 | SS_res = K.sum(K.square(y_true - y_pred)) 86 | SS_tot = K.sum(K.square(y_true - K.mean(y_true))) 87 | return (1 - SS_res / (SS_tot + K.epsilon())) 88 | 89 | def precision(y_true, y_pred): 90 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) 91 | predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1))) 92 | precision = true_positives / (predicted_positives + K.epsilon()) 93 | return precision 94 | 95 | 96 | def recall(y_true, y_pred): 97 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) 98 | possible_positives = K.sum(K.round(K.clip(y_true, 0, 1))) 99 | recall = true_positives / (possible_positives + K.epsilon()) 100 | return recall 101 | 102 | 103 | # Main function: 104 | 105 | def get_embeddings(X_train, y_train, categorical_embedding_info, is_classification, epochs=100, batch_size=256): 
106 | ''' 107 | this function trains a shallow neural networks and returns embeddings of categorical variables 108 | 109 | :X_train: training data [dataframe] 110 | :y_train: target variable 111 | :categorical_embedding_info: output of get_embedding_info function [dictionary of categorical variable and it's embedding size] 112 | :is_classification: True for classification tasks; False for regression tasks 113 | :epochs: num of epochs to train [default:100] 114 | :batch_size: batch size to train [default:256] 115 | 116 | It is a 2 layer neural network architecture with 1000 and 500 neurons with 'ReLU' activation 117 | for classification: loss = 'binary_crossentropy'; metrics = 'accuracy' 118 | for regression: loss = 'mean_squared_error'; metrics = 'r2' 119 | 120 | ''' 121 | 122 | numerical_variables = [x for x in X_train.columns if x not in list(categorical_embedding_info.keys())] 123 | 124 | inputs = [] 125 | flatten_layers = [] 126 | 127 | for var, sz in categorical_embedding_info.items(): 128 | input_c = Input(shape=(1,), dtype='int32') 129 | embed_c = Embedding(*sz, input_length=1)(input_c) 130 | flatten_c = Flatten()(embed_c) 131 | inputs.append(input_c) 132 | flatten_layers.append(flatten_c) 133 | #print(inputs) 134 | 135 | input_num = Input(shape=(len(numerical_variables),), dtype='float32') 136 | flatten_layers.append(input_num) 137 | inputs.append(input_num) 138 | 139 | flatten = concatenate(flatten_layers, axis=-1) 140 | 141 | fc1 = Dense(1000, kernel_initializer='normal')(flatten) 142 | fc1 = Activation('relu')(fc1) 143 | 144 | 145 | 146 | fc2 = Dense(500, kernel_initializer='normal')(fc1) 147 | fc2 = Activation('relu')(fc2) 148 | 149 | 150 | if is_classification: 151 | output = Dense(1, activation='sigmoid')(fc2) 152 | 153 | else: 154 | output = Dense(1, kernel_initializer='normal')(fc2) 155 | 156 | 157 | nnet = Model(inputs=inputs, outputs=output) 158 | 159 | x_inputs = [] 160 | for col in categorical_embedding_info.keys(): 161 | 
x_inputs.append(X_train[col].values) 162 | 163 | x_inputs.append(X_train[numerical_variables].values) 164 | 165 | if is_classification: 166 | loss = 'binary_crossentropy' 167 | metrics='accuracy' 168 | else: 169 | loss = 'mean_squared_error' 170 | metrics=r2 171 | 172 | 173 | 174 | nnet.compile(loss=loss, optimizer='adam', metrics=[metrics]) 175 | nnet.fit(x_inputs, y_train.values, batch_size=batch_size, epochs=epochs, validation_split=0.2, callbacks=[TQDMNotebookCallback()], verbose=0) 176 | 177 | embs = list(map(lambda x: x.get_weights()[0], [x for x in nnet.layers if 'Embedding' in str(x)])) 178 | embeddings = {var: emb for var, emb in zip(categorical_embedding_info.keys(), embs)} 179 | return embeddings 180 | 181 | def get_embeddings_in_dataframe(embeddings, encoders): 182 | ''' 183 | this function return the embeddings in pandas dataframe 184 | 185 | :embeddings: output of 'get_embeddings' function 186 | :encoders: output of 'get_embedding_info' function 187 | 188 | ''' 189 | 190 | assert len(embeddings)==len(encoders), "Categorical variables in embeddings does not match with those of encoders" 191 | 192 | dfs={} 193 | for cat_var in tqdm_notebook(embeddings.keys()): 194 | df = pd.DataFrame(embeddings[cat_var]) 195 | df.index = encoders[cat_var].classes_ 196 | df.columns = [cat_var + '_embedding_' + str(num) for num in df.columns] 197 | dfs[cat_var] = df 198 | 199 | return dfs 200 | 201 | 202 | def fit_transform(data, embeddings, encoders, drop_categorical_vars=False): 203 | ''' 204 | this function includes the trained embeddings into your data 205 | 206 | :data: input data [dataframe] 207 | :embeddings: output of 'get_embeddings' function 208 | :encoders: output of 'get_embedding_info' function 209 | :drop_categorical_vars: False to keep the categorical variables in the data along with the embeddings 210 | if True - drops the categorical variables and replaces them with trained embeddings 211 | 212 | ''' 213 | 214 | assert len(embeddings)==len(encoders), 
"Categorical variables in embeddings does not match with those of encoders" 215 | 216 | dfs={} 217 | for cat_var in tqdm_notebook(embeddings.keys()): 218 | df = pd.DataFrame(embeddings[cat_var]) 219 | df.index = encoders[cat_var].classes_ 220 | df.columns = [cat_var + '_embedding_' + str(num) for num in df.columns] 221 | data = data.merge(df, how='left', left_on=cat_var, right_index=True) 222 | 223 | if drop_categorical_vars: 224 | return data.drop(list(embeddings.keys()), axis=1) 225 | else: 226 | return data 227 | -------------------------------------------------------------------------------- /demo/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shivanandroy/CategoricalEmbedder/f4fcb184a4c520be7990cb4625ea5d040be7d707/demo/demo.gif -------------------------------------------------------------------------------- /dist/categorical_embedder-0.1-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shivanandroy/CategoricalEmbedder/f4fcb184a4c520be7990cb4625ea5d040be7d707/dist/categorical_embedder-0.1-py3-none-any.whl -------------------------------------------------------------------------------- /dist/categorical_embedder-0.1.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shivanandroy/CategoricalEmbedder/f4fcb184a4c520be7990cb4625ea5d040be7d707/dist/categorical_embedder-0.1.tar.gz -------------------------------------------------------------------------------- /example_notebook/Example Notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Categorical Embedder\n", 8 | "### Example" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 3, 14 | "metadata": {}, 15 | "outputs": [], 16 
| "source": [ 17 | "# Necessary imports\n", 18 | "import categorical_embedder as ce\n", 19 | "\n", 20 | "import numpy as np\n", 21 | "import pandas as pd\n", 22 | "from sklearn.model_selection import train_test_split\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 4, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "data": { 32 | "text/plain": [ 33 | "(54808, 14)" 34 | ] 35 | }, 36 | "execution_count": 4, 37 | "metadata": {}, 38 | "output_type": "execute_result" 39 | } 40 | ], 41 | "source": [ 42 | "# Reading data\n", 43 | "df = pd.read_csv('HR_Attrition_Data.csv')\n", 44 | "df.shape" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 5, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/html": [ 55 | "
\n", 56 | "\n", 69 | "\n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | "
employee_iddepartmentregioneducationgenderrecruitment_channelno_of_trainingsageprevious_year_ratinglength_of_serviceKPIs_met >80%awards_won?avg_training_scoreis_promoted
065438Sales & Marketingregion_7Master's & abovefsourcing1355.0810490
165141Operationsregion_22Bachelor'smother1305.0400600
27513Sales & Marketingregion_19Bachelor'smsourcing1343.0700500
32542Sales & Marketingregion_23Bachelor'smother2391.01000500
448945Technologyregion_26Bachelor'smother1453.0200730
\n", 177 | "
" 178 | ], 179 | "text/plain": [ 180 | " employee_id department region education gender \\\n", 181 | "0 65438 Sales & Marketing region_7 Master's & above f \n", 182 | "1 65141 Operations region_22 Bachelor's m \n", 183 | "2 7513 Sales & Marketing region_19 Bachelor's m \n", 184 | "3 2542 Sales & Marketing region_23 Bachelor's m \n", 185 | "4 48945 Technology region_26 Bachelor's m \n", 186 | "\n", 187 | " recruitment_channel no_of_trainings age previous_year_rating \\\n", 188 | "0 sourcing 1 35 5.0 \n", 189 | "1 other 1 30 5.0 \n", 190 | "2 sourcing 1 34 3.0 \n", 191 | "3 other 2 39 1.0 \n", 192 | "4 other 1 45 3.0 \n", 193 | "\n", 194 | " length_of_service KPIs_met >80% awards_won? avg_training_score \\\n", 195 | "0 8 1 0 49 \n", 196 | "1 4 0 0 60 \n", 197 | "2 7 0 0 50 \n", 198 | "3 10 0 0 50 \n", 199 | "4 2 0 0 73 \n", 200 | "\n", 201 | " is_promoted \n", 202 | "0 0 \n", 203 | "1 0 \n", 204 | "2 0 \n", 205 | "3 0 \n", 206 | "4 0 " 207 | ] 208 | }, 209 | "execution_count": 5, 210 | "metadata": {}, 211 | "output_type": "execute_result" 212 | } 213 | ], 214 | "source": [ 215 | "df.head()" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 7, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "X = df.drop(['employee_id', 'is_promoted'], axis=1)\n", 225 | "y = df['is_promoted']" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 8, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "data": { 235 | "text/plain": [ 236 | "{'department': (9, 5),\n", 237 | " 'region': (34, 17),\n", 238 | " 'education': (3, 2),\n", 239 | " 'gender': (2, 1),\n", 240 | " 'recruitment_channel': (3, 2)}" 241 | ] 242 | }, 243 | "execution_count": 8, 244 | "metadata": {}, 245 | "output_type": "execute_result" 246 | } 247 | ], 248 | "source": [ 249 | "# ce.get_embedding_info identifies the categorical variables, # of unique values and embedding size and returns a dictionary\n", 250 | "embedding_info = 
ce.get_embedding_info(X)\n", 251 | "embedding_info" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 10, 257 | "metadata": {}, 258 | "outputs": [ 259 | { 260 | "data": { 261 | "text/html": [ 262 | "
\n", 263 | "\n", 276 | "\n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | "
departmentregioneducationgenderrecruitment_channelno_of_trainingsageprevious_year_ratinglength_of_serviceKPIs_met >80%awards_won?avg_training_score
07312021355.081049
14140101305.040060
27100121343.070050
37150102391.0100050
48180101453.020073
\n", 372 | "
" 373 | ], 374 | "text/plain": [ 375 | " department region education gender recruitment_channel \\\n", 376 | "0 7 31 2 0 2 \n", 377 | "1 4 14 0 1 0 \n", 378 | "2 7 10 0 1 2 \n", 379 | "3 7 15 0 1 0 \n", 380 | "4 8 18 0 1 0 \n", 381 | "\n", 382 | " no_of_trainings age previous_year_rating length_of_service \\\n", 383 | "0 1 35 5.0 8 \n", 384 | "1 1 30 5.0 4 \n", 385 | "2 1 34 3.0 7 \n", 386 | "3 2 39 1.0 10 \n", 387 | "4 1 45 3.0 2 \n", 388 | "\n", 389 | " KPIs_met >80% awards_won? avg_training_score \n", 390 | "0 1 0 49 \n", 391 | "1 0 0 60 \n", 392 | "2 0 0 50 \n", 393 | "3 0 0 50 \n", 394 | "4 0 0 73 " 395 | ] 396 | }, 397 | "execution_count": 10, 398 | "metadata": {}, 399 | "output_type": "execute_result" 400 | } 401 | ], 402 | "source": [ 403 | "# ce.get_label_encoded_data integer encodes the categorical variables and prepares it to feed it to neural network\n", 404 | "X_encoded,encoders = ce.get_label_encoded_data(X)\n", 405 | "X_encoded.head()" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 11, 411 | "metadata": {}, 412 | "outputs": [ 413 | { 414 | "data": { 415 | "application/vnd.jupyter.widget-view+json": { 416 | "model_id": "c2cca5478d324f33a4c5e4c72706991b", 417 | "version_major": 2, 418 | "version_minor": 0 419 | }, 420 | "text/plain": [ 421 | 422 | ] 423 | }, 424 | "metadata": {}, 425 | "output_type": "display_data" 426 | }, 427 | 428 | { 429 | "name": "stdout", 430 | "output_type": "stream", 431 | "text": [ 432 | "\n" 433 | ] 434 | } 435 | ], 436 | "source": [ 437 | "# splitting the data into train and test\n", 438 | "X_train, X_test, y_train, y_test = train_test_split(X_encoded,y)\n", 439 | "\n", 440 | "# ce.get_embeddings trains NN, extracts embeddings and return a dictionary containing the embeddings\n", 441 | "embeddings = ce.get_embeddings(X_train, y_train, categorical_embedding_info=embedding_info, \n", 442 | " is_classification=True, epochs=100,batch_size=256)" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 
| "execution_count": 12, 448 | "metadata": {}, 449 | "outputs": [ 450 | { 451 | "data": { 452 | "text/plain": [ 453 | "{'department': array([[ 0.44909748, 0.592682 , -0.2689146 , -0.6076638 , -0.47688553],\n", 454 | " [ 0.14439532, 0.23831578, -0.09904855, -0.1884861 , -0.23708323],\n", 455 | " [ 0.02280043, 0.14768346, 0.00430288, -0.05229405, -0.06076226],\n", 456 | " [ 0.08651688, 0.33048603, -0.10082451, -0.24717978, -0.23439746],\n", 457 | " [ 0.06930665, 0.26183563, -0.099448 , -0.22151738, -0.24915719],\n", 458 | " [ 0.3246719 , 0.13284945, -0.49051526, -0.13767388, -0.35033587],\n", 459 | " [ 0.39557138, 0.6303038 , -0.31711328, -0.6432047 , -0.5024501 ],\n", 460 | " [ 0.105141 , 0.00382448, -0.16800691, 0.14332129, -0.09635292],\n", 461 | " [ 0.5065225 , 0.33804703, -0.4578551 , -0.3261275 , -0.34876052]],\n", 462 | " dtype=float32),\n", 463 | " 'region': array([[ 0.06167015, -0.09331849, -0.00821102, 0.35873163, 0.27501398,\n", 464 | " -0.18806422, 0.42246535, 0.16405596, 0.10364748, -0.01732335,\n", 465 | " 0.08967754, -0.4844684 , 0.1706062 , 0.07629129, -0.46060166,\n", 466 | " -0.250795 , -0.20905156],\n", 467 | " [ 0.15227915, -0.12810177, 0.04858813, 0.12176999, 0.24337043,\n", 468 | " -0.29888666, 0.28821293, 0.16845648, 0.09924047, -0.07546613,\n", 469 | " 0.20567334, -0.1352939 , 0.08037616, 0.11891732, -0.18827821,\n", 470 | " -0.0816901 , -0.2779371 ],\n", 471 | " [ 0.10410061, -0.2104857 , 0.03398978, 0.44951865, 0.20713083,\n", 472 | " -0.04877395, 0.14342847, 0.26965714, 0.07883961, -0.09597544,\n", 473 | " 0.20250317, -0.01135268, 0.21519178, 0.09570241, -0.17490706,\n", 474 | " -0.35628 , -0.2990015 ],\n", 475 | " [ 0.16129175, -0.18185025, 0.17493466, 0.26707488, 0.1694759 ,\n", 476 | " -0.09626276, 0.12718718, 0.21795489, 0.21134973, -0.11628247,\n", 477 | " 0.1386758 , -0.08041578, 0.16929497, 0.18136315, -0.19477838,\n", 478 | " -0.20189697, -0.17830281],\n", 479 | " [ 0.09706695, -0.19466406, 0.09813315, 0.17285876, 0.14107579,\n", 
480 | " -0.19056891, 0.13303407, 0.13789149, 0.19315891, -0.13101685,\n", 481 | " 0.14132002, -0.14025746, 0.2021801 , 0.1314787 , -0.20561455,\n", 482 | " -0.21507941, -0.19575128],\n", 483 | " [ 0.14802098, -0.17255649, 0.14310662, 0.11609807, 0.11347289,\n", 484 | " -0.16585284, 0.10970658, 0.12739295, 0.19427484, -0.14321955,\n", 485 | " 0.20610353, -0.28366777, 0.15577143, 0.19353567, -0.21552865,\n", 486 | " -0.11984877, -0.10363703],\n", 487 | " [ 0.17970042, -0.18435585, 0.2475642 , 0.04125668, 0.18393014,\n", 488 | " -0.2055146 , 0.20275053, 0.12765794, 0.18785101, -0.1876641 ,\n", 489 | " 0.16109551, -0.06933317, 0.19538549, 0.2865587 , -0.1244712 ,\n", 490 | " -0.04614212, -0.04774917],\n", 491 | " [ 0.1652677 , -0.18662676, 0.17290325, 0.23617998, 0.17689778,\n", 492 | " -0.18750137, 0.1188752 , 0.14298196, 0.15700985, -0.21114749,\n", 493 | " 0.11973537, -0.10869645, 0.12464906, 0.18343486, -0.1341778 ,\n", 494 | " -0.1960731 , -0.14148588],\n", 495 | " [ 0.3271403 , -0.20113371, 0.21159333, -0.10972808, 0.10864414,\n", 496 | " -0.10503934, -0.05350627, 0.06617992, 0.30653876, -0.2582119 ,\n", 497 | " 0.20189317, -0.08169097, 0.1769267 , 0.31721008, -0.03993533,\n", 498 | " -0.12251526, -0.03543968],\n", 499 | " [ 0.49103948, -0.2007103 , 0.6056571 , -0.5402952 , -0.10961785,\n", 500 | " -0.12077125, -0.3009251 , -0.265473 , 0.787379 , -0.27324182,\n", 501 | " 0.14153819, -0.71441156, 0.06250548, 0.693984 , 0.03170918,\n", 502 | " 0.15812905, 0.21315786],\n", 503 | " [ 0.145392 , -0.14274164, 0.15068698, 0.22807477, 0.2243693 ,\n", 504 | " -0.21936256, 0.20291108, 0.19875804, 0.16094448, -0.11601669,\n", 505 | " 0.1970405 , -0.17728642, 0.179156 , 0.16524756, -0.2120733 ,\n", 506 | " -0.18664533, -0.15705594],\n", 507 | " [ 0.1262483 , -0.11328196, 0.0883195 , 0.22781551, 0.22445866,\n", 508 | " -0.14024241, 0.23455557, 0.18774618, 0.11778355, -0.1181474 ,\n", 509 | " 0.14379439, -0.23156083, 0.16031402, 0.11968873, -0.24753772,\n", 510 | " 
-0.16268265, -0.17190713],\n", 511 | " [ 0.23302501, -0.20846668, 0.1955039 , 0.15930428, 0.15813963,\n", 512 | " -0.20879413, 0.15088034, 0.1160711 , 0.19584629, -0.19314173,\n", 513 | " 0.18014538, 0.08821156, 0.20087811, 0.24331777, -0.09434079,\n", 514 | " -0.16702151, -0.11515248],\n", 515 | " [ 0.28533554, -0.17394873, 0.28509635, 0.22203562, 0.22575945,\n", 516 | " -0.15222326, 0.11049732, 0.2120857 , 0.16156977, -0.23141564,\n", 517 | " 0.24763632, 0.22988008, 0.2439133 , 0.2348142 , -0.02985817,\n", 518 | " -0.14367686, -0.16408649],\n", 519 | " [ 0.25052086, -0.14586733, 0.16408962, 0.0507778 , 0.13811931,\n", 520 | " -0.09476005, 0.06604631, 0.15955344, 0.15546182, -0.19960758,\n", 521 | " 0.13166466, -0.1917724 , 0.19645979, 0.21055347, -0.13072439,\n", 522 | " -0.21560103, -0.14708862],\n", 523 | " [ 0.15596968, -0.096491 , 0.2651417 , 0.0172758 , 0.15028748,\n", 524 | " -0.00850711, 0.01992784, 0.03281375, 0.292076 , -0.05552545,\n", 525 | " 0.10409604, -0.48480755, 0.13718422, 0.3167684 , -0.19583265,\n", 526 | " -0.18808214, -0.01285333],\n", 527 | " [ 0.14563759, -0.20024602, 0.10641596, 0.29899475, 0.17734939,\n", 528 | " -0.17023069, 0.13852215, 0.3017684 , 0.11778694, -0.16828966,\n", 529 | " 0.24574074, 0.05353521, 0.2246284 , 0.0889548 , -0.16901372,\n", 530 | " -0.3217882 , -0.32629952],\n", 531 | " [ 0.2638189 , -0.12104501, 0.0855846 , 0.03768857, 0.17154141,\n", 532 | " -0.1629056 , 0.11832082, 0.08007334, 0.17927748, -0.11453451,\n", 533 | " 0.22714885, -0.33345717, 0.06720065, 0.21814711, -0.17346677,\n", 534 | " -0.06780529, -0.20093854],\n", 535 | " [ 0.25641713, -0.16766143, 0.28437874, 0.2066256 , 0.18670245,\n", 536 | " -0.16507165, 0.06784681, 0.18451001, 0.13122343, -0.29931074,\n", 537 | " 0.17858475, 0.07724513, 0.19355355, 0.23004377, -0.09734175,\n", 538 | " -0.12910831, -0.08001687],\n", 539 | " [ 0.16712832, -0.18242997, 0.2117202 , 0.07055467, 0.12361196,\n", 540 | " -0.24991389, 0.17608581, 0.22188991, 0.23487635, 
-0.15777037,\n", 541 | " 0.13870557, 0.00351198, 0.13455318, 0.18125626, -0.12799425,\n", 542 | " -0.16291897, -0.12791933],\n", 543 | " [ 0.18096384, -0.14182384, 0.21343406, 0.1738329 , 0.12259424,\n", 544 | " -0.0726446 , 0.04926556, 0.16455975, 0.31186566, -0.13384931,\n", 545 | " 0.10702487, -0.07755449, 0.18760253, 0.28193462, -0.0987668 ,\n", 546 | " -0.14985453, -0.07982537],\n", 547 | " [ 0.07552038, -0.23918265, -0.02730448, 0.31189373, 0.21811438,\n", 548 | " -0.20897698, 0.24290127, 0.25800923, 0.09748205, -0.11501911,\n", 549 | " 0.13102911, -0.12043046, 0.18119033, 0.01878749, -0.2620637 ,\n", 550 | " -0.2751533 , -0.29920864],\n", 551 | " [ 0.26392713, -0.20426226, 0.16867518, -0.17597376, 0.20485684,\n", 552 | " -0.39169973, 0.2588289 , 0.1385429 , 0.20912765, -0.14703502,\n", 553 | " 0.19859625, -0.21648338, 0.03238742, 0.22506148, -0.153454 ,\n", 554 | " -0.02380576, -0.12335522],\n", 555 | " [ 0.24141465, -0.16093457, 0.23083627, 0.04343623, 0.12495948,\n", 556 | " -0.1838547 , 0.12102125, 0.10463292, 0.24187995, -0.18191727,\n", 557 | " 0.2392767 , -0.13200898, 0.16538313, 0.20711762, -0.08967044,\n", 558 | " -0.10814557, -0.19145034],\n", 559 | " [ 0.11075923, -0.18888554, 0.25069082, 0.038624 , 0.2005392 ,\n", 560 | " -0.28150958, 0.16630962, 0.2176757 , 0.15907326, -0.24305162,\n", 561 | " 0.13209617, 0.10084943, 0.16876855, 0.2158818 , -0.09895746,\n", 562 | " -0.09712282, -0.21898103],\n", 563 | " [ 0.03361392, -0.09406326, 0.09817785, 0.2577049 , 0.30454093,\n", 564 | " -0.18041614, 0.3392363 , 0.27354133, 0.04589733, -0.05175021,\n", 565 | " 0.18609524, -0.30759484, 0.23998573, 0.05730164, -0.4153282 ,\n", 566 | " -0.21307011, -0.15814604],\n", 567 | " [ 0.15096396, -0.13050652, 0.12978545, 0.3539185 , 0.25473082,\n", 568 | " -0.15637578, 0.29674196, 0.25461993, 0.03402874, -0.15752971,\n", 569 | " 0.19972315, -0.03535189, 0.18373075, 0.06211969, -0.19313417,\n", 570 | " -0.15184213, -0.24135719],\n", 571 | " [ 0.29311594, -0.2278486 , 
0.10147502, 0.45177343, 0.24334185,\n", 572 | " -0.2585309 , 0.27289554, 0.3444534 , -0.11046959, -0.30788636,\n", 573 | " 0.28252742, 0.527343 , 0.18462309, 0.10991604, 0.01184443,\n", 574 | " -0.140023 , -0.29176837],\n", 575 | " [ 0.22212954, -0.0972956 , 0.0757604 , 0.00119282, 0.10628817,\n", 576 | " -0.04451968, 0.01875878, 0.04793382, 0.29772508, -0.13264635,\n", 577 | " 0.21271782, -0.420673 , 0.14043371, 0.2398636 , -0.14792743,\n", 578 | " -0.17283182, -0.13257244],\n", 579 | " [ 0.22231904, -0.18607658, 0.10973643, 0.2119783 , 0.24073566,\n", 580 | " -0.18088289, 0.21400924, 0.18904465, 0.05207201, -0.21133952,\n", 581 | " 0.247929 , 0.11467847, 0.21312265, 0.19659802, -0.10511905,\n", 582 | " -0.21219096, -0.17890845],\n", 583 | " [ 0.11948653, -0.13136931, 0.10426308, 0.3505519 , 0.31600142,\n", 584 | " -0.32948446, 0.30525216, 0.3088015 , -0.00783463, -0.22288117,\n", 585 | " 0.19944812, 0.05813633, 0.11685727, 0.13461244, -0.1513958 ,\n", 586 | " -0.13962963, -0.33049056],\n", 587 | " [ 0.02869872, -0.18674235, 0.17713551, 0.06570877, 0.16253975,\n", 588 | " -0.14332786, 0.1743341 , 0.05855466, 0.22767399, -0.09349857,\n", 589 | " 0.07442123, -0.38788933, 0.17868695, 0.25265324, -0.2979716 ,\n", 590 | " -0.11825583, -0.03945841],\n", 591 | " [ 0.37929267, -0.22558057, 0.29641664, 0.13798186, 0.17428468,\n", 592 | " -0.16322204, 0.10453338, 0.18077698, 0.1327219 , -0.27045852,\n", 593 | " 0.2009385 , 0.1341174 , 0.14552666, 0.2806392 , -0.02933537,\n", 594 | " -0.10095037, -0.17500046],\n", 595 | " [ 0.409364 , -0.26337704, 0.21481998, 0.16067739, 0.15580739,\n", 596 | " -0.3264027 , 0.19733264, 0.27058062, 0.07919446, -0.4325755 ,\n", 597 | " 0.31672868, 0.647895 , 0.1496081 , 0.23755707, 0.04658863,\n", 598 | " -0.06037766, -0.24323867]], dtype=float32),\n", 599 | " 'education': array([[-0.38068944, 0.28985274],\n", 600 | " [-0.308726 , 0.30023918],\n", 601 | " [-0.18988381, 0.46631435]], dtype=float32),\n", 602 | " 'gender': 
array([[0.38233313],\n", 603 | " [0.38486874]], dtype=float32),\n", 604 | " 'recruitment_channel': array([[-0.29799673, 0.37393624],\n", 605 | " [-0.31355464, 0.3837454 ],\n", 606 | " [-0.39704534, 0.2786043 ]], dtype=float32)}" 607 | ] 608 | }, 609 | "execution_count": 12, 610 | "metadata": {}, 611 | "output_type": "execute_result" 612 | } 613 | ], 614 | "source": [ 615 | "embeddings" 616 | ] 617 | }, 618 | { 619 | "cell_type": "code", 620 | "execution_count": 13, 621 | "metadata": {}, 622 | "outputs": [ 623 | { 624 | "data": { 625 | "application/vnd.jupyter.widget-view+json": { 626 | "model_id": "c7bfbe036d2d48b7a2a1203e0cf8e19b", 627 | "version_major": 2, 628 | "version_minor": 0 629 | }, 630 | "text/plain": [ 631 | 632 | ] 633 | }, 634 | "metadata": {}, 635 | "output_type": "display_data" 636 | }, 637 | { 638 | "name": "stdout", 639 | "output_type": "stream", 640 | "text": [ 641 | "\n" 642 | ] 643 | } 644 | ], 645 | "source": [ 646 | "# if you don't like the dictionary format, convert it to a dataframe for easy readability\n", 647 | "dfs = ce.get_embeddings_in_dataframe(embeddings=embeddings, encoders=encoders)" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": 15, 653 | "metadata": {}, 654 | "outputs": [ 655 | { 656 | "data": { 657 | "text/html": [ 658 | "
\n", 659 | "\n", 672 | "\n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | "
department_embedding_0department_embedding_1department_embedding_2department_embedding_3department_embedding_4
Analytics0.4490970.592682-0.268915-0.607664-0.476886
Finance0.1443950.238316-0.099049-0.188486-0.237083
HR0.0228000.1476830.004303-0.052294-0.060762
Legal0.0865170.330486-0.100825-0.247180-0.234397
Operations0.0693070.261836-0.099448-0.221517-0.249157
Procurement0.3246720.132849-0.490515-0.137674-0.350336
R&D0.3955710.630304-0.317113-0.643205-0.502450
Sales & Marketing0.1051410.003824-0.1680070.143321-0.096353
Technology0.5065220.338047-0.457855-0.326127-0.348761
\n", 758 | "
" 759 | ], 760 | "text/plain": [ 761 | " department_embedding_0 department_embedding_1 \\\n", 762 | "Analytics 0.449097 0.592682 \n", 763 | "Finance 0.144395 0.238316 \n", 764 | "HR 0.022800 0.147683 \n", 765 | "Legal 0.086517 0.330486 \n", 766 | "Operations 0.069307 0.261836 \n", 767 | "Procurement 0.324672 0.132849 \n", 768 | "R&D 0.395571 0.630304 \n", 769 | "Sales & Marketing 0.105141 0.003824 \n", 770 | "Technology 0.506522 0.338047 \n", 771 | "\n", 772 | " department_embedding_2 department_embedding_3 \\\n", 773 | "Analytics -0.268915 -0.607664 \n", 774 | "Finance -0.099049 -0.188486 \n", 775 | "HR 0.004303 -0.052294 \n", 776 | "Legal -0.100825 -0.247180 \n", 777 | "Operations -0.099448 -0.221517 \n", 778 | "Procurement -0.490515 -0.137674 \n", 779 | "R&D -0.317113 -0.643205 \n", 780 | "Sales & Marketing -0.168007 0.143321 \n", 781 | "Technology -0.457855 -0.326127 \n", 782 | "\n", 783 | " department_embedding_4 \n", 784 | "Analytics -0.476886 \n", 785 | "Finance -0.237083 \n", 786 | "HR -0.060762 \n", 787 | "Legal -0.234397 \n", 788 | "Operations -0.249157 \n", 789 | "Procurement -0.350336 \n", 790 | "R&D -0.502450 \n", 791 | "Sales & Marketing -0.096353 \n", 792 | "Technology -0.348761 " 793 | ] 794 | }, 795 | "execution_count": 15, 796 | "metadata": {}, 797 | "output_type": "execute_result" 798 | } 799 | ], 800 | "source": [ 801 | "dfs['department']" 802 | ] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": 16, 807 | "metadata": {}, 808 | "outputs": [ 809 | { 810 | "data": { 811 | "text/html": [ 812 | "
\n", 813 | "\n", 826 | "\n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | "
education_embedding_0education_embedding_1
Bachelor's-0.3806890.289853
Below Secondary-0.3087260.300239
Master's & above-0.1898840.466314
\n", 852 | "
" 853 | ], 854 | "text/plain": [ 855 | " education_embedding_0 education_embedding_1\n", 856 | "Bachelor's -0.380689 0.289853\n", 857 | "Below Secondary -0.308726 0.300239\n", 858 | "Master's & above -0.189884 0.466314" 859 | ] 860 | }, 861 | "execution_count": 16, 862 | "metadata": {}, 863 | "output_type": "execute_result" 864 | } 865 | ], 866 | "source": [ 867 | "dfs['education']" 868 | ] 869 | }, 870 | { 871 | "cell_type": "code", 872 | "execution_count": 20, 873 | "metadata": {}, 874 | "outputs": [ 875 | { 876 | "data": { 877 | "application/vnd.jupyter.widget-view+json": { 878 | "model_id": "f4ab55d9f7f2449486bc246aeb2bf686", 879 | "version_major": 2, 880 | "version_minor": 0 881 | }, 882 | "text/plain": [ 883 | "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))" 884 | ] 885 | }, 886 | "metadata": {}, 887 | "output_type": "display_data" 888 | }, 889 | { 890 | "name": "stdout", 891 | "output_type": "stream", 892 | "text": [ 893 | "\n" 894 | ] 895 | }, 896 | { 897 | "data": { 898 | "text/html": [ 899 | "
\n", 900 | "\n", 913 | "\n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " 
\n", 1061 | " \n", 1062 | "
no_of_trainingsageprevious_year_ratinglength_of_serviceKPIs_met >80%awards_won?avg_training_scoredepartment_embedding_0department_embedding_1department_embedding_2...region_embedding_12region_embedding_13region_embedding_14region_embedding_15region_embedding_16education_embedding_0education_embedding_1gender_embedding_0recruitment_channel_embedding_0recruitment_channel_embedding_1
01355.0810490.1051410.003824-0.168007...0.1786870.252653-0.297972-0.118256-0.039458-0.1898840.4663140.382333-0.3970450.278604
11305.0400600.0693070.261836-0.099448...0.1964600.210553-0.130724-0.215601-0.147089-0.3806890.2898530.384869-0.2979970.373936
21343.0700500.1051410.003824-0.168007...0.1791560.165248-0.212073-0.186645-0.157056-0.3806890.2898530.384869-0.3970450.278604
32391.01000500.1051410.003824-0.168007...0.1371840.316768-0.195833-0.188082-0.012853-0.3806890.2898530.384869-0.2979970.373936
41453.0200730.5065220.338047-0.457855...0.1935540.230044-0.097342-0.129108-0.080017-0.3806890.2898530.384869-0.2979970.373936
\n", 1063 | "

5 rows × 34 columns

\n", 1064 | "
" 1065 | ], 1066 | "text/plain": [ 1067 | " no_of_trainings age previous_year_rating length_of_service \\\n", 1068 | "0 1 35 5.0 8 \n", 1069 | "1 1 30 5.0 4 \n", 1070 | "2 1 34 3.0 7 \n", 1071 | "3 2 39 1.0 10 \n", 1072 | "4 1 45 3.0 2 \n", 1073 | "\n", 1074 | " KPIs_met >80% awards_won? avg_training_score department_embedding_0 \\\n", 1075 | "0 1 0 49 0.105141 \n", 1076 | "1 0 0 60 0.069307 \n", 1077 | "2 0 0 50 0.105141 \n", 1078 | "3 0 0 50 0.105141 \n", 1079 | "4 0 0 73 0.506522 \n", 1080 | "\n", 1081 | " department_embedding_1 department_embedding_2 ... region_embedding_12 \\\n", 1082 | "0 0.003824 -0.168007 ... 0.178687 \n", 1083 | "1 0.261836 -0.099448 ... 0.196460 \n", 1084 | "2 0.003824 -0.168007 ... 0.179156 \n", 1085 | "3 0.003824 -0.168007 ... 0.137184 \n", 1086 | "4 0.338047 -0.457855 ... 0.193554 \n", 1087 | "\n", 1088 | " region_embedding_13 region_embedding_14 region_embedding_15 \\\n", 1089 | "0 0.252653 -0.297972 -0.118256 \n", 1090 | "1 0.210553 -0.130724 -0.215601 \n", 1091 | "2 0.165248 -0.212073 -0.186645 \n", 1092 | "3 0.316768 -0.195833 -0.188082 \n", 1093 | "4 0.230044 -0.097342 -0.129108 \n", 1094 | "\n", 1095 | " region_embedding_16 education_embedding_0 education_embedding_1 \\\n", 1096 | "0 -0.039458 -0.189884 0.466314 \n", 1097 | "1 -0.147089 -0.380689 0.289853 \n", 1098 | "2 -0.157056 -0.380689 0.289853 \n", 1099 | "3 -0.012853 -0.380689 0.289853 \n", 1100 | "4 -0.080017 -0.380689 0.289853 \n", 1101 | "\n", 1102 | " gender_embedding_0 recruitment_channel_embedding_0 \\\n", 1103 | "0 0.382333 -0.397045 \n", 1104 | "1 0.384869 -0.297997 \n", 1105 | "2 0.384869 -0.397045 \n", 1106 | "3 0.384869 -0.297997 \n", 1107 | "4 0.384869 -0.297997 \n", 1108 | "\n", 1109 | " recruitment_channel_embedding_1 \n", 1110 | "0 0.278604 \n", 1111 | "1 0.373936 \n", 1112 | "2 0.278604 \n", 1113 | "3 0.373936 \n", 1114 | "4 0.373936 \n", 1115 | "\n", 1116 | "[5 rows x 34 columns]" 1117 | ] 1118 | }, 1119 | "execution_count": 20, 1120 | "metadata": {}, 1121 
| "output_type": "execute_result" 1122 | } 1123 | ], 1124 | "source": [ 1125 | "# include these embeddings in your dataset:\n", 1126 | "data = ce.fit_transform(X, embeddings=embeddings, encoders=encoders, drop_categorical_vars=True)\n", 1127 | "data.head()" 1128 | ] 1129 | } 1130 | ], 1131 | "metadata": { 1132 | "kernelspec": { 1133 | "display_name": "Python 3", 1134 | "language": "python", 1135 | "name": "python3" 1136 | }, 1137 | "language_info": { 1138 | "codemirror_mode": { 1139 | "name": "ipython", 1140 | "version": 3 1141 | }, 1142 | "file_extension": ".py", 1143 | "mimetype": "text/x-python", 1144 | "name": "python", 1145 | "nbconvert_exporter": "python", 1146 | "pygments_lexer": "ipython3", 1147 | "version": "3.7.4" 1148 | } 1149 | }, 1150 | "nbformat": 4, 1151 | "nbformat_minor": 2 1152 | } 1153 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | #Here is the module name. 8 | name="categorical_embedder", 9 | 10 | #version of the module 11 | version="0.1", 12 | 13 | #Name of Author 14 | author="Shivanand Roy", 15 | 16 | #your Email address 17 | author_email="Shivanandroy.official@gmail.com", 18 | 19 | #Small Description about module 20 | description="Categorical Embedder is a python package that let's you convert your categorical variables into numeric via Neural Networks", 21 | 22 | long_description=long_description, 23 | 24 | #Specifying that we are using markdown file for description 25 | long_description_content_type="text/markdown", 26 | url="https://github.com/Shivanandroy/CategoricalEmbedder/", 27 | packages=setuptools.find_packages(), 28 | 29 | #classifiers like program is suitable for python3, just leave as it is. 
30 | classifiers=[ 31 | "Programming Language :: Python :: 3", 32 | "License :: OSI Approved :: MIT License", 33 | "Operating System :: OS Independent", 34 | ], 35 | install_requires=['numpy','pandas','tensorflow','keras','tqdm','keras-tqdm','sklearn'] 36 | ) 37 | --------------------------------------------------------------------------------