├── report.pdf
├── similar_word.py
├── main.py
├── naiveBayesW2V.py
├── naiveBayesCountVec.py
├── utils.py
├── lstm_predict.py
├── word2vec_visualize.py
├── config.py
├── preprocess.py
├── README.md
├── word2vec.py
├── lstm.py
├── .gitignore
└── preprocess.ipynb

/report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coffee-cup/mbti/HEAD/report.pdf
--------------------------------------------------------------------------------
/similar_word.py:
--------------------------------------------------------------------------------
from gensim.models import Word2Vec

from utils import get_config

if __name__ == '__main__':
    config, unparsed = get_config(return_unparsed=True)

    if len(unparsed) != 1:
        print('Please provide a word')
        exit(1)

    word = unparsed[0]

    model = Word2Vec.load(config.embeddings_model)
    print('Words similar to ' + word)
    for w, s in model.wv.most_similar(word):
        print('{:>15} {:.2f}%'.format(w, s * 100))
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# The main file to run all steps

from preprocess import preprocess
from utils import (ALL, FIRST, FOURTH, SECOND, THIRD, get_char_for_binary,
                   get_config, one_hot_to_type)
from word2vec import word2vec

if __name__ == '__main__':
    config = get_config()

    preprocess(config)

    # 16 classes (batch=False so each element is a single [sentence, label] row)
    embedding_data = word2vec(config, code=ALL, batch=False)
    example = embedding_data[10]
    print('Label is {}'.format(one_hot_to_type(example[1])))

    # Binary class (fourth code)
    # code = FOURTH
    # embedding_data = word2vec(config, code=code, batch=False)
    # example = embedding_data[10]
    # print('Binary label is {}'.format(get_char_for_binary(code, example[1])))
--------------------------------------------------------------------------------
/naiveBayesW2V.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

from utils import get_config, one_hot_to_type
from word2vec import word2vec


def meanEmbeddingTransform(feature):
    """Reduce each post (a list of word vectors) to a single scalar
    feature: the mean of the per-word vector means."""
    return np.array(
        [[np.mean([np.mean(word) for word in sentence if np.any(word)]
                  or [0])] for sentence in feature])


config = get_config()

# Unbatched rows of [sentence, one-hot label]
data = word2vec(config, batch=False)

data_train, data_test = train_test_split(data, test_size=0.2)
X_train, y_train = map(list, zip(*data_train))
X_test, y_test = map(list, zip(*data_test))

# Convert the one-hot labels back to type strings for sklearn
y_train = np.array([one_hot_to_type(y) for y in y_train])
y_test = np.array([one_hot_to_type(y) for y in y_test])

X_train = meanEmbeddingTransform(X_train)
X_test = meanEmbeddingTransform(X_test)

clf = GaussianNB()
clf.fit(X_train, y_train)
prediction = clf.predict(X_test)
print(classification_report(y_test, prediction))
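
# For example (a toy sketch): a post made of two 3-dim word vectors,
#   [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
# has per-word means [2.0, 4.0], so meanEmbeddingTransform reduces the
# whole post to the single feature [3.0]. Every post is summarised by one
# scalar before it reaches GaussianNB.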
--------------------------------------------------------------------------------
/naiveBayesCountVec.py:
--------------------------------------------------------------------------------
import csv

import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB

X_data = []
y_data = []


def regex_tokenizer(doc):
    """Split a document into a sequence of tokens on spaces."""
    return doc.split(' ')


with open("./data/mbti_preprocessed.csv") as csvFile:
    reader = csv.reader(csvFile)
    next(reader)  # skip the header row written by DataFrame.to_csv
    count = 0
    for line in reader:
        # Skip empty posts
        if not line[1].strip():
            continue
        X_data.append(line[1])
        y_data.append(line[2])
        count += 1
        # Cap the sample count so the dense count matrices stay small
        if count == 20000:
            break

X_train, X_test, y_train, y_test = train_test_split(
    X_data, y_data, test_size=0.2)

vectorizer = CountVectorizer(
    lowercase=False,
    stop_words=None,
    max_df=1.0,
    min_df=1,
    max_features=None,
    tokenizer=regex_tokenizer)

X_train = vectorizer.fit_transform(X_train).toarray()
X_test = vectorizer.transform(X_test).toarray()

print(X_train)

y_test = np.asarray(y_test)
y_train = np.asarray(y_train)

print(len(X_train))
print(len(y_train))

clf = MultinomialNB()
clf.fit(X_train, y_train)
prediction = clf.predict(X_test)
print(classification_report(y_test, prediction))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# Utils file for useful functions

import os

import numpy as np

from config import parse_config, print_usage

first_codes = ['I', 'E']
second_codes = ['S', 'N']
third_codes = ['T', 'F']
fourth_codes = ['J', 'P']
codes = [first_codes, second_codes, third_codes, fourth_codes]

FIRST = 0
SECOND = 1
THIRD = 2
FOURTH = 3
ALL = 16

personality_types = [
    'ISTJ', 'ISFJ', 'INFJ', 'INTJ', 'ISTP', 'ISFP', 'INFP', 'INTP', 'ESTP',
    'ESFP', 'ENFP', 'ENTP', 'ESTJ', 'ESFJ', 'ENFJ', 'ENTJ'
]


def one_hot_encode_type(t):
    i = personality_types.index(t)
    Y = np.zeros(len(personality_types))
    Y[i] = 1
    return Y.astype(int).tolist()


def one_hot_to_type(Y):
    i = np.where(np.array(Y) == 1)[0][0]
    return personality_types[i]


def get_binary_for_code(code, personality_type):
    c = codes[code]
    return int(personality_type[code] != c[0])


def get_char_for_binary(code, binary):
    if type(binary) is list:
        binary = binary[0]
    c = codes[code]
    return c[binary]


def get_config(return_unparsed=False):
    """Gets config and creates data_dir."""
    config, unparsed = parse_config()

    # If we have unparsed args, print usage and exit
    if len(unparsed) > 0 and not return_unparsed:
        print_usage()
        exit(1)

    def append_data_dir(p):
        return os.path.join(config.data_dir, p)

    # Append data_dir to all filepaths
    config.pre_save_file = append_data_dir(config.pre_save_file)
    config.raw_csv_file = append_data_dir(config.raw_csv_file)
    config.embeddings_model = append_data_dir(config.embeddings_model)
    config.embeddings_file = append_data_dir(config.embeddings_file)

    # Create data_dir if it doesn't exist
    if not os.path.exists(config.data_dir):
os.makedirs(config.data_dir) 72 | 73 | if return_unparsed: 74 | return config, unparsed 75 | 76 | return config 77 | -------------------------------------------------------------------------------- /lstm_predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import sys 4 | 5 | import numpy as np 6 | import torch 7 | import torch.autograd as autograd 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torch.autograd import Variable 12 | 13 | from lstm import LSTMClassifier, MbtiDataset 14 | from preprocess import preprocess_text 15 | from utils import (FIRST, FOURTH, SECOND, THIRD, codes, get_char_for_binary, 16 | get_config) 17 | from word2vec import load_word2vec, word2vec 18 | 19 | 20 | def np_sentence_to_list(L_sent): 21 | newsent = [] 22 | for sentance in L_sent: 23 | temp = [] 24 | for word in sentance: 25 | temp.append(word.tolist()) 26 | newsent.append(temp) 27 | return newsent 28 | 29 | 30 | def load_model(config, code): 31 | model_file = 'saves/{}_model'.format(code) 32 | model = LSTMClassifier( 33 | config, 34 | embedding_dim=config.feature_size, 35 | hidden_dim=128, 36 | label_size=2) 37 | model.load_state_dict(torch.load(model_file)) 38 | return model 39 | 40 | 41 | def predict(config, text, code, model=None, embedding_input=None): 42 | if model is None: 43 | model = load_model(config, code) 44 | 45 | preprocessed = preprocess_text(text) 46 | 47 | if embedding_input is None: 48 | embedding = [] 49 | word_model = load_word2vec(config.embeddings_model) 50 | for word in preprocessed.split(' '): 51 | if word in word_model.wv.index2word: 52 | vec = word_model.wv[word] 53 | embedding.append(vec) 54 | 55 | embedding_input = Variable( 56 | torch.Tensor(np_sentence_to_list(embedding))) 57 | 58 | pred = model(embedding_input) 59 | pred_label = pred.data.max(1)[1].numpy()[0] 60 | pred_char = get_char_for_binary(code, pred_label) 61 | return pred_char 62 | 63 | 64 | if __name__ == '__main__': 65 | config = get_config() 66 | 67 | # Python 2/3 input 68 | try: 69 | input = raw_input 70 | except NameError: 71 | pass 72 | 73 | if sys.stdin.isatty(): 74 | text = input('Enter some text: ') 75 | else: 76 | text = sys.stdin.read() 77 | 78 | personality = '' 79 | codes = [FIRST, SECOND, THIRD, FOURTH] 80 | for code in codes: 81 | personality += predict(config, text, code) 82 | 83 | print('Prediction is {}'.format(personality)) 84 | -------------------------------------------------------------------------------- /word2vec_visualize.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import random 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from gensim.models import Word2Vec 8 | from mpl_toolkits.mplot3d import Axes3D 9 | from sklearn.manifold import TSNE 10 | 11 | from utils import get_config 12 | 13 | 14 | def load_model(config): 15 | """Load the word2vec model from disk.""" 16 | return Word2Vec.load(config.embeddings_model) 17 | 18 | 19 | def get_vocab(model, n): 20 | """Returns n labels and vectors belonging to word2vec model.""" 21 | labels = [] 22 | tokens = [] 23 | 24 | i = 0 25 | 26 | items = list(model.wv.vocab.items()) 27 | random.shuffle(items) 28 | for word, _ in items: 29 | tokens.append(model[word]) 30 | labels.append(word) 31 | 32 | i += 1 33 | if i >= n: 34 | break 35 | 36 | return labels, tokens 37 | 38 | 39 | def plot3d(labels, tokens): 40 | """Plot word vectors in 3d.""" 41 | 
print('Plotting {} points in 3d'.format(len(labels))) 42 | 43 | # Reduce dimensionality with TSNE 44 | tsne_model = TSNE( 45 | perplexity=40, 46 | n_components=3, 47 | init='pca', 48 | n_iter=2500, 49 | learning_rate=600) 50 | new_values = tsne_model.fit_transform(tokens) 51 | 52 | x = [] 53 | y = [] 54 | z = [] 55 | for value in new_values: 56 | x.append(value[0]) 57 | y.append(value[1]) 58 | z.append(value[2]) 59 | 60 | fig = plt.figure() 61 | ax = fig.add_subplot(111, projection='3d') 62 | 63 | ax.scatter(x, y, z, c='b', marker='.', edgecolors='none') 64 | 65 | for i, l in enumerate(labels): 66 | ax.text(x[i], y[i], z[i], l) 67 | 68 | plt.show() 69 | 70 | 71 | def plot2d(labels, tokens): 72 | """Plot word vectors in 2d.""" 73 | print('Plotting {} points in 2d'.format(len(labels))) 74 | 75 | # Reduce dimensionality with TSNE 76 | tsne_model = TSNE( 77 | perplexity=40, 78 | n_components=2, 79 | init='pca', 80 | n_iter=2000, 81 | learning_rate=500) 82 | new_values = tsne_model.fit_transform(tokens) 83 | 84 | x = [] 85 | y = [] 86 | for value in new_values: 87 | x.append(value[0]) 88 | y.append(value[1]) 89 | 90 | fig = plt.figure() 91 | ax = fig.add_subplot(111) 92 | 93 | for i in range(len(x)): 94 | ax.scatter(x[i], y[i], marker='.', c='b', s=100, edgecolors='none') 95 | plt.annotate( 96 | labels[i], 97 | xy=(x[i], y[i]), 98 | xytext=(5, 2), 99 | textcoords='offset points', 100 | ha='right', 101 | va='bottom') 102 | 103 | plt.show() 104 | 105 | 106 | if __name__ == '__main__': 107 | config = get_config() 108 | model = load_model(config) 109 | 110 | N = 100 111 | labels, tokens = get_vocab(model, N) 112 | plot2d(labels, tokens) 113 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | # Global variables within this script 4 | arg_lists = [] 5 | parser = argparse.ArgumentParser() 6 | 7 | 8 | # Some nice macros to be used for arparse 9 | def str2bool(v): 10 | return v.lower() in ('true', '1') 11 | 12 | 13 | def add_argument_group(name): 14 | arg = parser.add_argument_group(name) 15 | arg_lists.append(arg) 16 | return arg 17 | 18 | 19 | parser.add_argument( 20 | '--data_dir', 21 | type=str, 22 | default='./data', 23 | help='Directory to save/read all data files from') 24 | 25 | parser.add_argument( 26 | '--use_cuda', 27 | type=str2bool, 28 | default=False, 29 | help='Whether or not to use cuda for PyTorch') 30 | 31 | # Arguments for preprocessing 32 | preprocessing_arg = add_argument_group('Preprocessing') 33 | 34 | preprocessing_arg.add_argument( 35 | '--raw_csv_file', 36 | type=str, 37 | default='mbti_1.csv', 38 | help='Filename of csv file downloaded from Kaggle') 39 | 40 | preprocessing_arg.add_argument( 41 | '--pre_save_file', 42 | type=str, 43 | default='mbti_preprocessed.csv', 44 | help='Filename to save preprocessed csv file as') 45 | 46 | preprocessing_arg.add_argument( 47 | '--force_preprocessing', 48 | type=str2bool, 49 | default=False, 50 | help='Whether or not to do preprocessing even if output csv file is found') 51 | 52 | # Arguments for word2vec 53 | word2vec_arg = add_argument_group('Word2Vec') 54 | 55 | word2vec_arg.add_argument( 56 | '--embeddings_model', 57 | type=str, 58 | default='embeddings_model', 59 | help='Filename to save word2vec model to') 60 | 61 | word2vec_arg.add_argument( 62 | '--embeddings_file', 63 | type=str, 64 | default='vector_data', 65 | help='Filename to save mbti data with word vectors to') 66 | 67 | 
word2vec_arg.add_argument(
    '--num_threads',
    type=int,
    default=4,
    help='Number of threads to use for training word2vec')

word2vec_arg.add_argument(
    '--feature_size',
    type=int,
    default=300,
    help='Number of features to use for word2vec')

word2vec_arg.add_argument(
    '--min_words',
    type=int,
    default=10,
    help='Ignore words with total frequency lower than this')

word2vec_arg.add_argument(
    '--distance_between_words',
    type=int,
    default=10,
    help='Maximum distance between current and predicted word')

word2vec_arg.add_argument(
    '--epochs',
    type=int,
    default=10,
    help='Number of epochs to train word2vec and the lstm for')

word2vec_arg.add_argument(
    '--force_word2vec',
    type=str2bool,
    default=False,
    help=
    'Whether or not to create word embeddings even if output word2vec file is found'
)

word2vec_arg.add_argument(
    '--num_samples',
    type=int,
    default=-1,
    help='Number of samples to return from word2vec. -1 for all samples')

lstm_arg = add_argument_group('LSTM')

lstm_arg.add_argument(
    '--batch_size',
    type=int,
    default=32,
    help='Number of samples to use per iteration when training lstm')


def parse_config():
    config, unparsed = parser.parse_known_args()

    return config, unparsed


def print_usage():
    parser.print_usage()
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
import os.path
import re
import string

import numpy as np
import pandas as pd
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

from utils import *

# Regular expression to match punctuation
reg_punc = re.compile('[%s]' % re.escape(string.punctuation))

# Regular expression to match links
reg_link = re.compile(r'http\S+', flags=re.MULTILINE)

# Regular expression to match non-alpha characters
reg_alpha = re.compile('[^a-zA-Z ]')

# Regular expression to match all whitespace
reg_spaces = re.compile(r'\s+', flags=re.MULTILINE)


def filter_text(post):
    """Decide whether or not we want to use the post."""
    return len(post.split(' ')) >= 7


def preprocess_text(post):
    """Remove any junk we don't want to use in the post."""

    # Remove links
    post = reg_link.sub('', post)

    # All lowercase
    post = post.lower()

    # Remove non-alpha chars
    post = reg_alpha.sub('', post)

    # Replace multiple whitespace with single space
    post = reg_spaces.sub(' ', post)

    # Remove stop words
    stop_words = set(stopwords.words('english'))
    word_tokens = word_tokenize(post)
    post = [w for w in word_tokens if w not in stop_words]
    post = ' '.join(post)

    # Strip whitespace
    post = post.strip()

    return post
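
# For example (a rough sketch of the pipeline above):
#   preprocess_text('Hello! Check out https://example.com :)')
# strips the link, lowercases, drops non-alpha characters and stopwords,
# and returns 'hello check'. A post this short is then rejected by
# filter_text, which keeps only posts of at least 7 words.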


def create_new_rows(row):
    """Create new rows of the data by preprocessing the individual posts and filtering out bad ones."""
    posts = row['posts'].split('|||')
    rows = []

    for p in posts:
        p = preprocess_text(p)
        if not filter_text(p):
            continue
        rows.append({'type': row['type'], 'post': p})
    return rows


def preprocess(config):
    """Preprocess the data using the config.

    :config user configuration
    """
    print('\n--- Preprocessing')

    if os.path.isfile(config.pre_save_file) and not config.force_preprocessing:
        df = pd.read_csv(config.pre_save_file)
        return df.values

    df = pd.read_csv(config.raw_csv_file)
    newrows = []
    for index, row in df.iterrows():
        newrows += create_new_rows(row)

    df = pd.DataFrame(newrows)
    df.to_csv(config.pre_save_file)

    # Read the file back so the returned rows have the same shape
    # (index, post, type) whether or not the cached csv already existed
    return pd.read_csv(config.pre_save_file).values


def get_count(posts, fn):
    counts_dict = {}
    for row in posts:
        label = row[-1]
        l = fn(label)
        counts_dict[l] = counts_dict.get(l, 0) + 1

    return sorted(counts_dict.items(), key=lambda t: -t[1])


def print_counts(counts):
    total = 0.0
    for t in counts:
        total += t[1]

    for t in counts:
        percent = (t[1] / total) * 100
        print('{} {} {:.2f}%'.format(t[0], t[1], percent))
    print('')


if __name__ == '__main__':
    config = get_config()
    posts = preprocess(config)

    # Visualize the preprocessing
    print('Preprocess Complete!')
    print('{} Total rows'.format(len(posts)))
    print('Here are the first 2 rows')
    print(posts[0:2])

    print('\nCount of labels for all 16 classes and each character\n')

    print('All')
    print_counts(get_count(posts, lambda x: x))

    print('First')
    print_counts(get_count(posts, lambda x: x[0]))

    print('Second')
    print_counts(get_count(posts, lambda x: x[1]))

    print('Third')
    print_counts(get_count(posts, lambda x: x[2]))

    print('Fourth')
    print_counts(get_count(posts, lambda x: x[3]))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MBTI Personality Type Predictor

Personality type predictor using PyTorch. Attempts to predict the full 16-way personality type as well as each binary character code along each axis.

View the [report here](https://github.com/coffee-cup/uvic-data-mining-mbti/blob/master/report.pdf).

## Dataset

Download and extract the csv file from [Kaggle](https://www.kaggle.com/datasnaek/mbti-type/version/1).

Place the extracted file in a directory called `data`.
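
For example (assuming the archive downloaded from Kaggle is named `mbti-type.zip`; adjust to whatever name you get):

```sh
mkdir -p data
unzip mbti-type.zip -d data   # should produce data/mbti_1.csv
```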

## Usage

You can run all the steps at once with

```sh
python main.py
```

Or run each step individually

```sh
python preprocess.py
python word2vec.py
```

To see all of the config variables, run

```sh
python main.py --help
```

```
usage: main.py [-h] [--data_dir DATA_DIR] [--raw_csv_file RAW_CSV_FILE]
               [--pre_save_file PRE_SAVE_FILE]
               [--force_preprocessing FORCE_PREPROCESSING]
               [--embeddings_model EMBEDDINGS_MODEL]
               [--embeddings_file EMBEDDINGS_FILE] [--num_threads NUM_THREADS]
               [--feature_size FEATURE_SIZE] [--min_words MIN_WORDS]
               [--distance_between_words DISTANCE_BETWEEN_WORDS]
               [--epochs EPOCHS] [--force_word2vec FORCE_WORD2VEC]
               [--num_samples NUM_SAMPLES]

optional arguments:
  -h, --help            show this help message and exit
  --data_dir DATA_DIR   Directory to save/read all data files from

Preprocessing:
  --raw_csv_file RAW_CSV_FILE
                        Filename of csv file downloaded from Kaggle
  --pre_save_file PRE_SAVE_FILE
                        Filename to save preprocessed csv file as
  --force_preprocessing FORCE_PREPROCESSING
                        Whether or not to do preprocessing even if output csv
                        file is found

Word2Vec:
  --embeddings_model EMBEDDINGS_MODEL
                        Filename to save word2vec model to
  --embeddings_file EMBEDDINGS_FILE
                        Filename to save mbti data with word vectors to
  --num_threads NUM_THREADS
                        Number of threads to use for training word2vec
  --feature_size FEATURE_SIZE
                        Number of features to use for word2vec
  --min_words MIN_WORDS
                        Ignore words with total frequency lower than this
  --distance_between_words DISTANCE_BETWEEN_WORDS
                        Maximum distance between current and predicted word
  --epochs EPOCHS       Number of epochs to train word2vec and the lstm for
  --force_word2vec FORCE_WORD2VEC
                        Whether or not to create word embeddings even if
                        output word2vec file is found
  --num_samples NUM_SAMPLES
                        Number of samples to return from word2vec. -1 for all
                        samples
```

# Data Format

The format of the data can get a little confusing. Hopefully this clears things up.

For the following, `N = number of rows (samples) we have`.

_Note: All filepaths are prefixed with the `data` directory._

## Preprocessing

### Input

Raw CSV file coming from Kaggle. The location of the input file is given by `config.raw_csv_file`.

```python
preprocess(config) # the preprocessed rows are returned and a new csv file is saved
```

### Output

The file is preprocessed by splitting each row into a new row for each individual post. Stopwords, numbers, links, and punctuation are removed and the text is set to all lowercase. The file is saved to `config.pre_save_file`.
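
For example, you can sanity-check the result (assuming the default `--data_dir` and `--pre_save_file` values):

```python
import pandas as pd

df = pd.read_csv('data/mbti_preprocessed.csv', index_col=0)
print(len(df))    # one row per individual post
print(df.head())  # the 'post' and 'type' columns
```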

## Word2Vec

This is the data that will be mainly used for training/testing. Multiple output types can be specified depending on whether you are training to classify all 16 classes or doing a binary classification for each of the 4 character codes.

### Input

Preprocessed CSV file. The location of the input file is given by `config.pre_save_file`.

As input you also need to give the personality "character code". The options are imported from `utils.py`.

```python
from utils import ALL, FIRST, FOURTH, SECOND, THIRD
embedding_data = word2vec(config, code=ALL) # Defaults to ALL
```

### Output

The output will all be numbers; no strings will be present.

For each row, the first element will be the sentence data and the second element will be the label vector.

```
row = [sentence, label]
```

#### Sentence

The sentence data is a list of word vectors, so it may be a different length for each row.

#### Word Vector

Each word vector is a vector of length `config.feature_size`, which defaults to 300.

#### Label

The label depends on the `code` option specified.

**ALL**

The label will be a length-16 vector which is one-hot encoded. You can use the `utils.one_hot_to_type` function to convert from a one-hot encoding to a personality type.

_For example_

```python
# Get one hot encoding
Y = one_hot_encode_type('INTJ')
print(Y)
# => [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

# Get personality type
t = one_hot_to_type(Y)
print(t)
# => INTJ
```

**FIRST, SECOND, THIRD, FOURTH**

The label will be a length-1 vector whose entry is either 0 or 1, so training reduces to a binary classification. To get which character the binary class stands for, you can use the `utils.get_char_for_binary` function.

_For example_

```python
# Consider the third character (T or F)
code = THIRD

# Get binary class
b = get_binary_for_code(code, 'ESTP')
print(b)
# => 0

# Get character for class
c = get_char_for_binary(code, b)
print(c)
# => T
```
--------------------------------------------------------------------------------
/word2vec.py:
--------------------------------------------------------------------------------
import os
import random

import numpy as np
import pandas as pd
from gensim.models import Word2Vec

from tqdm import trange
from utils import (ALL, FIRST, FOURTH, SECOND, THIRD, get_binary_for_code,
                   get_config, one_hot_encode_type)


def create_word2vec_model(data, config):
    """Train a word2vec model.

    :data List of tokenised posts to train on

    Config values used:
      num_threads             Number of worker threads while training
      feature_size            Dimensionality of the word vectors
      min_words               Ignore all words with total frequency lower than this
      distance_between_words  Maximum distance between the current and
                              predicted word within a sentence
      epochs                  Number of iterations to run over the corpus
    """
    print('Training model...')
    model = Word2Vec(
        data,
        workers=config.num_threads,
        size=config.feature_size,
        min_count=config.min_words,
        window=config.distance_between_words,
        sample=1e-3,
        sorted_vocab=1,
        iter=config.epochs)
    model.init_sims(replace=True)

    return model


def load_word2vec(name):
    """Load word2vec model from file."""
    return Word2Vec.load(name)
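
# A rough sketch of how training and loading pair up (see word2vec()
# below for the real flow):
#
#   words = extract_words(pre_data)
#   model = create_word2vec_model(words, config)
#   model.save(config.embeddings_model)
#   ...
#   model = load_word2vec(config.embeddings_model)
#   vec = model.wv['thinking']  # a config.feature_size-dim vector, if the
#                               # word survived the min_words cutoff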


def extract_words(data):
    """Extract words from posts.

    :data preprocessed data
    """
    words = []
    for row in data:
        post = row[1]
        words.append(post.split())

    return words


def convert_data_to_index(posts, model):
    """Match each word of a post to the index in the model.
    Note: not every word is in the model
    """
    index_data = []

    for post in posts:
        index_of_post = []
        for word in post:

            if word in model.wv:
                index_of_post.append(model.wv.vocab[word].index)

        index_data.append(index_of_post)

    return index_data


def convert_posts_to_vectors(config, data, model):
    """Convert the text posts to vectors.

    :data pd.read_csv(pre_processed_csv).values
    :model word2vec model

    Returns array of length N of format [mbti type, sentence array];
    the sentence array is an array of words where each word is a
    config.feature_size-dim vector (300 by default).
    """
    print('Converting text data to vectors...')

    N = len(data)
    # Get number of samples to use from config if not -1
    if config.num_samples != -1:
        N = config.num_samples

    # Shuffle data
    random.shuffle(data)
    embedded_data = []
    for idx in trange(N):
        row = data[idx]
        post = row[1]
        mbti_type = row[2]

        sentence = []
        for word in post.split(' '):
            # Membership test against the vocab dict (index2word is a
            # list, so `word in index2word` would be an O(V) scan)
            if word in model.wv.vocab:
                vec = model.wv[word]
                sentence.append(vec)
        if len(sentence) > 0:
            embedded_data.append([mbti_type, sentence])

    return embedded_data


def get_embeddings(model):
    """Convert the keyedVectors of the model into numpy arrays."""
    num_features = len(model.wv[model.wv.index2word[0]])

    embedded_weights = np.zeros((len(model.wv.vocab), num_features))
    for i in range(len(model.wv.vocab)):
        embedding_vector = model.wv[model.wv.index2word[i]]
        if embedding_vector is not None:
            embedded_weights[i] = embedding_vector

    return embedded_weights


def get_code_data(code, embedding_data):
    """Get data with label as binary specifying a specific personality type code."""
    newdata = []
    for row in embedding_data:
        c = get_binary_for_code(code, row[0])
        newdata.append([row[1], [c]])
    return newdata


def get_one_hot_data(embedding_data):
    """Get data with label one-hot encoded for all possible classes."""
    newdata = []
    for row in embedding_data:
        Y = one_hot_encode_type(row[0])
        newdata.append([row[1], Y])
    return newdata
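
# For example (a sketch), for a row whose type is 'ENTJ':
#   get_one_hot_data(...) labels it with a one-hot length-16 vector
#     ([0, ..., 0, 1] here, since ENTJ is the last entry of
#     utils.personality_types), while
#   get_code_data(FIRST, ...) labels it with the length-1 binary vector
#     [1], since 'E' is the second option of the FIRST code ['I', 'E'].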


def word2vec(config, code=ALL, batch=True, pre_data=None):
    """Create word2vec embeddings.

    :config user configuration
    """
    print('\n--- Creating word embeddings')

    if pre_data is None:
        pre_data = pd.read_csv(config.pre_save_file).values

    if os.path.isfile(config.embeddings_model) and not config.force_word2vec:
        # Load model from file
        model = load_word2vec(config.embeddings_model)
    else:
        # Train model
        words = extract_words(pre_data)
        model = create_word2vec_model(words, config)

        # Save model to disk
        model.save(config.embeddings_model)

    # Create data with labels and word embeddings
    embedding_data = convert_posts_to_vectors(config, pre_data, model)

    if code == ALL:
        embedding_data = get_one_hot_data(embedding_data)
    else:
        embedding_data = get_code_data(code, embedding_data)

    return batch_embeddings(embedding_data) if batch else embedding_data


def average_sentence_length(embedding_data):
    """Returns the average number of words in a sentence"""
    total_words = 0
    for row in embedding_data:
        total_words += float(len(row[0]))
    return total_words / len(embedding_data)


def batch_embeddings(embeddings):
    """Group rows into batches by sentence length.

    Each returned batch contains only rows whose sentences have the same
    number of words, so a batch can be stacked into a single tensor.
    """
    batched_embeddings = []
    lengths = set()
    for row in embeddings:
        lengths.add(len(row[0]))
    max_length = max(lengths)

    for i in range(0, max_length + 1):
        temp = []
        for row in embeddings:
            if len(row[0]) == i:
                temp.append(row)
        if len(temp) != 0:
            batched_embeddings.append(temp)

    return batched_embeddings


# NOTE: the HDF5 save/load helpers below don't currently work (and the
# getter would shadow get_embeddings above), so they are left disabled
# def get_embeddings(config):
#     """Returns data rows with embeddings from disk."""
#     data = pd.HDFStore(config.embeddings_file)
#     return data['data'].values

# def save_embeddings(config, embedding_data):
#     """Saves data rows with embeddings to disk."""
#     x = pd.HDFStore(config.embeddings_file)
#     x.append('data', pd.DataFrame(embedding_data))
#     x.close()

if __name__ == "__main__":
    config = get_config()

    # Ask for unbatched rows so they can be batched (and counted) here
    embedding_data = word2vec(config, batch=False)
    batched_embeddings = batch_embeddings(embedding_data)
    print(len(embedding_data), len(batched_embeddings))
    print('Created word embeddings')

    print('Rows: {}'.format(len(embedding_data)))
    print('Average number of words: {}'.format(
        average_sentence_length(embedding_data)))
--------------------------------------------------------------------------------
/lstm.py:
--------------------------------------------------------------------------------
import os
import pickle
import random

import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset

from utils import FIRST, FOURTH, SECOND, THIRD, codes, get_config
from word2vec import word2vec

# random.seed(1)
# torch.manual_seed(1)


class MbtiDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


class LSTMClassifier(nn.Module):
    def __init__(self, config, embedding_dim, hidden_dim, label_size):
        super(LSTMClassifier, self).__init__()
        self.config = config
        self.label_size = label_size
        self.hidden_dim = hidden_dim
        # self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, dropout=0.2)
        self.hidden2label = nn.Linear(hidden_dim, label_size)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        h1 = autograd.Variable(torch.zeros(1, 1, self.hidden_dim))
        h2 = autograd.Variable(torch.zeros(1, 1, self.hidden_dim))

        # Move the initial states onto the GPU when CUDA is enabled
        if self.config.use_cuda:
            h1 = h1.cuda()
            h2 = h2.cuda()
        return (h1, h2)
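
    # The tuple returned above is the LSTM's initial (h_0, c_0) pair: the
    # hidden state and the cell state, each of shape
    # (num_layers=1, batch=1, hidden_dim) -- e.g. (1, 1, 128) with the
    # hidden_dim used in this project.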

    def forward(self, embeds):
        # embeds has shape (seq_len, batch, embedding_dim)
        x = embeds
        lstm_out, self.hidden = self.lstm(x, self.hidden)
        y = self.hidden2label(lstm_out[-1])
        log_probs = F.log_softmax(y, dim=1)
        return log_probs


def train_epoch(model, dataloader, loss_fn, optimizer, epoch):
    '''Train a single epoch.'''
    model.train()

    avg_loss = 0.0
    count = 0
    total_samples = 0
    correct = 0.0

    for i_batch, sample_batched in enumerate(dataloader):
        inputs, labels = sample_batched
        inputs = Variable(torch.stack(inputs))
        labels = Variable(torch.stack(labels)).view(-1)

        model.hidden = model.init_hidden()

        output = model(inputs)
        _, predict = torch.max(output, 1)

        correct += (predict.data.numpy() == labels.data.numpy()).sum()
        total_samples += labels.size()[0]

        optimizer.zero_grad()
        loss = loss_fn(output, labels)
        avg_loss += loss.data[0]
        count += 1

        if count % 100 == 0:
            print('\tEpoch: {} Iteration: {} Loss: {}'.format(
                epoch, count, loss.data[0]))

        loss.backward()
        optimizer.step()

    avg_loss /= count
    acc = correct / total_samples
    print('Epoch: {} Avg Loss: {} Acc: {:.2f}%'.format(epoch, avg_loss,
                                                       acc * 100))
    return avg_loss, acc


def evaluate(model, dataloader):
    model.eval()

    truth_res = []
    pred_res = []

    correct = 0.0
    total_samples = 0

    for i_batch, sample_batched in enumerate(dataloader):
        inputs, labels = sample_batched
        inputs = Variable(torch.stack(inputs))
        labels = Variable(torch.stack(labels)).view(-1)

        model.hidden = model.init_hidden()
        output = model(inputs)

        _, predict = torch.max(output, 1)
        correct += (predict.data.numpy() == labels.data.numpy()).sum()
        total_samples += labels.size()[0]

        truth_res += labels.data.numpy().tolist()
        pred_res += predict.data.numpy().tolist()

    acc = correct / total_samples
    metrics = precision_recall_fscore_support(
        truth_res, pred_res, average='micro')
    return acc, metrics


def evenly_distribute(X, y):
    """Downsample so that both binary classes have the same number of samples."""
    counts = [0, 0]

    for l in y:
        counts[l[0]] += 1

    new_X = []
    new_y = []
    min_count = min(counts[0], counts[1])
    print('Min sample count: {}'.format(min_count))

    new_counts = [0, 0]
    for i in range(0, len(X)):
        l = y[i][0]
        if new_counts[l] < min_count:
            new_X.append(X[i])
            new_y.append(y[i])
            new_counts[l] += 1

        if new_counts[0] >= min_count and new_counts[1] >= min_count:
            break

    return new_X, new_y
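
# For example (a sketch): with 500 rows labelled [1] and 300 labelled [0],
# min_count is 300 and evenly_distribute returns 300 samples of each
# class, discarding the surplus of the majority class.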


def lstm(config, embedding_data, code):
    X = [row[0] for row in embedding_data]
    y = [row[1] for row in embedding_data]

    # Evenly distribute across classes
    X, y = evenly_distribute(X, y)

    print('Total samples: {}'.format(len(X)))

    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)

    train_dataset = MbtiDataset(X_train, y_train)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4)

    test_dataset = MbtiDataset(X_val, y_val)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4)

    label_size = 2
    EMBEDDING_DIM = config.feature_size
    HIDDEN_DIM = 128
    best_acc = 0.0

    model = LSTMClassifier(
        config,
        embedding_dim=EMBEDDING_DIM,
        hidden_dim=HIDDEN_DIM,
        label_size=label_size)

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    # The model already applies log_softmax in forward, so NLLLoss is the
    # matching criterion (log_softmax + NLLLoss == cross entropy on logits)
    loss_fn = nn.NLLLoss()
    optimizer = optim.Adam(parameters, lr=1e-3)

    losses = []
    train_accs = []
    test_accs = []

    best_model = None
    best_metrics = None
    # Note: config.epochs is shared with word2vec training
    for i in range(config.epochs):
        train_loss, train_acc = train_epoch(model, train_dataloader, loss_fn,
                                            optimizer, i)
        losses.append(train_loss)
        train_accs.append(train_acc)

        acc, metrics = evaluate(model, test_dataloader)
        test_accs.append(acc)

        print('Epoch #{} Val Acc: {:.2f}%'.format(i, acc * 100))
        print('')

        if acc > best_acc:
            best_acc = acc
            best_model = model.state_dict()
            best_metrics = metrics

    save_data = {
        'best_acc': best_acc,
        'best_metrics': best_metrics,
        'losses': losses,
        'train_accs': train_accs,
        'test_accs': test_accs,
        'personality_char': code + 1,
        'letters': codes[code]
    }

    print('Best Acc: {:.2f}%'.format(best_acc * 100))

    # Make sure the output directories exist before saving
    os.makedirs('saves', exist_ok=True)
    os.makedirs('results', exist_ok=True)

    # Save the best model
    torch.save(best_model, 'saves/{}_model'.format(code))

    # Save the results
    filename = 'results/{}_save'.format(code)
    with open(filename, 'wb') as f:
        pickle.dump(save_data, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_model(config, code):
    model_file = 'saves/{}_model'.format(code)
    model = LSTMClassifier(
        config,
        embedding_dim=config.feature_size,
        hidden_dim=128,
        label_size=2)
    model.load_state_dict(torch.load(model_file))
    return model


if __name__ == '__main__':
    config = get_config()
    pre_data = pd.read_csv(config.pre_save_file).values
    split = int(len(pre_data) * 0.9)

    trainval = pre_data[:split]
    test = pre_data[split:]

    # Save trainval and test datasets
    with open('trainval_set', 'wb') as f:
        pickle.dump(trainval, f, protocol=pickle.HIGHEST_PROTOCOL)

    with open('test_set', 'wb') as f:
        pickle.dump(test, f, protocol=pickle.HIGHEST_PROTOCOL)

    for code in [FIRST, SECOND, THIRD, FOURTH]:
        embedding_data = word2vec(
            config, code=code, batch=False, pre_data=trainval)
        lstm(config, embedding_data, code)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

# Created by https://www.gitignore.io/api/emacs,macos,python,windows,pycharm,sublimetext,visualstudiocode

### Emacs ###
# -*- mode: gitignore; -*-
*~
\#*\#
/.emacs.desktop
/.emacs.desktop.lock
*.elc
auto-save-list
tramp
.\#*

# Org-mode
.org-id-locations
*_archive

# flymake-mode
*_flymake.*

# eshell files
/eshell/history
/eshell/lastdir

# elpa packages
/elpa/ 28 | 29 | # reftex files 30 | *.rel 31 | 32 | # AUCTeX auto folder 33 | /auto/ 34 | 35 | # cask packages 36 | .cask/ 37 | dist/ 38 | 39 | # Flycheck 40 | flycheck_*.el 41 | 42 | # server auth directory 43 | /server/ 44 | 45 | # projectiles files 46 | .projectile 47 | projectile-bookmarks.eld 48 | 49 | # directory configuration 50 | .dir-locals.el 51 | 52 | # saveplace 53 | places 54 | 55 | # url cache 56 | url/cache/ 57 | 58 | # cedet 59 | ede-projects.el 60 | 61 | # smex 62 | smex-items 63 | 64 | # company-statistics 65 | company-statistics-cache.el 66 | 67 | # anaconda-mode 68 | anaconda-mode/ 69 | 70 | ### macOS ### 71 | *.DS_Store 72 | .AppleDouble 73 | .LSOverride 74 | 75 | # Icon must end with two \r 76 | Icon 77 | 78 | # Thumbnails 79 | ._* 80 | 81 | # Files that might appear in the root of a volume 82 | .DocumentRevisions-V100 83 | .fseventsd 84 | .Spotlight-V100 85 | .TemporaryItems 86 | .Trashes 87 | .VolumeIcon.icns 88 | .com.apple.timemachine.donotpresent 89 | 90 | # Directories potentially created on remote AFP share 91 | .AppleDB 92 | .AppleDesktop 93 | Network Trash Folder 94 | Temporary Items 95 | .apdisk 96 | 97 | ### PyCharm ### 98 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 99 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 100 | 101 | # User-specific stuff: 102 | .idea/**/workspace.xml 103 | .idea/**/tasks.xml 104 | .idea/dictionaries 105 | 106 | # Sensitive or high-churn files: 107 | .idea/**/dataSources/ 108 | .idea/**/dataSources.ids 109 | .idea/**/dataSources.xml 110 | .idea/**/dataSources.local.xml 111 | .idea/**/sqlDataSources.xml 112 | .idea/**/dynamic.xml 113 | .idea/**/uiDesigner.xml 114 | 115 | # Gradle: 116 | .idea/**/gradle.xml 117 | .idea/**/libraries 118 | 119 | # CMake 120 | cmake-build-debug/ 121 | 122 | # Mongo Explorer plugin: 123 | .idea/**/mongoSettings.xml 124 | 125 | ## File-based project format: 126 | *.iws 127 | 128 | ## Plugin-specific files: 129 | 130 | # IntelliJ 131 | /out/ 132 | 133 | # mpeltonen/sbt-idea plugin 134 | .idea_modules/ 135 | 136 | # JIRA plugin 137 | atlassian-ide-plugin.xml 138 | 139 | # Cursive Clojure plugin 140 | .idea/replstate.xml 141 | 142 | # Ruby plugin and RubyMine 143 | /.rakeTasks 144 | 145 | # Crashlytics plugin (for Android Studio and IntelliJ) 146 | com_crashlytics_export_strings.xml 147 | crashlytics.properties 148 | crashlytics-build.properties 149 | fabric.properties 150 | 151 | ### PyCharm Patch ### 152 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 153 | 154 | # *.iml 155 | # modules.xml 156 | # .idea/misc.xml 157 | # *.ipr 158 | 159 | # Sonarlint plugin 160 | .idea/sonarlint 161 | 162 | ### Python ### 163 | # Byte-compiled / optimized / DLL files 164 | __pycache__/ 165 | *.py[cod] 166 | *$py.class 167 | 168 | # C extensions 169 | *.so 170 | 171 | # Distribution / packaging 172 | .Python 173 | build/ 174 | develop-eggs/ 175 | downloads/ 176 | eggs/ 177 | .eggs/ 178 | lib/ 179 | lib64/ 180 | parts/ 181 | sdist/ 182 | var/ 183 | wheels/ 184 | *.egg-info/ 185 | .installed.cfg 186 | *.egg 187 | 188 | # PyInstaller 189 | # Usually these files are written by a python script from a template 190 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
.pytest_cache/
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule.*

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

### SublimeText ###
# cache files for sublime text
*.tmlanguage.cache
*.tmPreferences.cache
*.stTheme.cache

# workspace files are user-specific
*.sublime-workspace

# project files should be checked into the repository, unless a significant
# proportion of contributors will probably not be using SublimeText
# *.sublime-project

# sftp configuration file
sftp-config.json

# Package control specific files
Package Control.last-run
Package Control.ca-list
Package Control.ca-bundle
Package Control.system-ca-bundle
Package Control.cache/
Package Control.ca-certs/
Package Control.merged-ca-bundle
Package Control.user-ca-bundle
oscrypto-ca-bundle.crt
bh_unicode_properties.cache

# Sublime-github package stores a github token in this file
# https://packagecontrol.io/packages/sublime-github
GitHub.sublime-settings

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
.history

### Windows ###
# Windows thumbnail cache files
Thumbs.db
ehthumbs.db
ehthumbs_vista.db

# Folder config file
Desktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msm
*.msp

# Windows shortcuts
*.lnk


# End of https://www.gitignore.io/api/emacs,macos,python,windows,pycharm,sublimetext,visualstudiocode

# Datasets
dataset.zip
*.csv
data/
test.png
results/*
saves/*
saves*
--------------------------------------------------------------------------------
/preprocess.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load MBTI Dataset\n",
    "\n",
    "You should have extracted the zip file and have a CSV file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "filename = 'mbti_1.csv'\n",
    "outfilename = 'mbti_preprocessed.csv'"
   ]
}, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": { 28 | "collapsed": true 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "import numpy as np\n", 33 | "import pandas as pd\n", 34 | "import re\n", 35 | "import string" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "data": { 45 | "text/html": [ 46 | "
\n", 47 | "\n", 60 | "\n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | "
typeposts
0INFJ'http://www.youtube.com/watch?v=qsXHcwe3krw|||...
1ENTP'I'm finding the lack of me in these posts ver...
2INTP'Good one _____ https://www.youtube.com/wat...
3INTJ'Dear INTP, I enjoyed our conversation the o...
4ENTJ'You're fired.|||That's another silly misconce...
5INTJ'18/37 @.@|||Science is not perfect. No scien...
6INFJ'No, I can't draw on my own nails (haha). Thos...
7INTJ'I tend to build up a collection of things on ...
8INFJI'm not sure, that's a good question. The dist...
9INTP'https://www.youtube.com/watch?v=w8-egj0y8Qs||...
\n", 121 | "
" 122 | ], 123 | "text/plain": [ 124 | " type posts\n", 125 | "0 INFJ 'http://www.youtube.com/watch?v=qsXHcwe3krw|||...\n", 126 | "1 ENTP 'I'm finding the lack of me in these posts ver...\n", 127 | "2 INTP 'Good one _____ https://www.youtube.com/wat...\n", 128 | "3 INTJ 'Dear INTP, I enjoyed our conversation the o...\n", 129 | "4 ENTJ 'You're fired.|||That's another silly misconce...\n", 130 | "5 INTJ '18/37 @.@|||Science is not perfect. No scien...\n", 131 | "6 INFJ 'No, I can't draw on my own nails (haha). Thos...\n", 132 | "7 INTJ 'I tend to build up a collection of things on ...\n", 133 | "8 INFJ I'm not sure, that's a good question. The dist...\n", 134 | "9 INTP 'https://www.youtube.com/watch?v=w8-egj0y8Qs||..." 135 | ] 136 | }, 137 | "execution_count": 3, 138 | "metadata": {}, 139 | "output_type": "execute_result" 140 | } 141 | ], 142 | "source": [ 143 | "df_ = pd.read_csv(filename)\n", 144 | "\n", 145 | "# Preview the rows\n", 146 | "df_.head(10)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "# Preprocess\n", 154 | "\n", 155 | "## We need to split the posts on the ||| string and create a new row with the same type" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 4, 161 | "metadata": { 162 | "collapsed": true 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "newrows = []\n", 167 | "\n", 168 | "def filter_text(post):\n", 169 | " \"\"\"Decide whether or not we want to use the post.\"\"\"\n", 170 | " # should remove link only posts here\n", 171 | " return len(post) > 0\n", 172 | " \n", 173 | "reg_punc = re.compile('[%s]' % re.escape(string.punctuation))\n", 174 | "def preprocess_text(post):\n", 175 | " \"\"\"Remove any junk we don't want to use in the post.\"\"\"\n", 176 | " \n", 177 | " # Remove links\n", 178 | " post = re.sub(r'http\\S+', '', post, flags=re.MULTILINE)\n", 179 | " \n", 180 | " # All lowercase\n", 181 | " post = post.lower()\n", 182 | " \n", 183 | " # Remove puncutation\n", 184 | " post = reg_punc.sub('', post)\n", 185 | " \n", 186 | " return post\n", 187 | "\n", 188 | "def create_new_rows(row):\n", 189 | " posts = row['posts'].split('|||')\n", 190 | " rows = []\n", 191 | " \n", 192 | " for p in posts:\n", 193 | " p = preprocess_text(p)\n", 194 | " if not filter_text(p):\n", 195 | " continue\n", 196 | " rows.append({'type': row['type'], 'post': p})\n", 197 | " return rows\n", 198 | "\n", 199 | "for index, row in df_.iterrows():\n", 200 | " newrows += create_new_rows(row)\n", 201 | " \n", 202 | "df = pd.DataFrame(newrows)\n", 203 | "unique = df.groupby('type').nunique()" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 5, 209 | "metadata": {}, 210 | "outputs": [ 211 | { 212 | "name": "stdout", 213 | "output_type": "stream", 214 | "text": [ 215 | "411495 rows\n" 216 | ] 217 | }, 218 | { 219 | "data": { 220 | "text/html": [ 221 | "
\n", 222 | "\n", 235 | "\n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | "
posttype
0enfp and intj moments sportscenter not top ...INFJ
1what has been the most lifechanging experience...INFJ
2on repeat for most of todayINFJ
3may the perc experience immerse youINFJ
4the last thing my infj friend posted on his fa...INFJ
5hello enfj7 sorry to hear of your distress its...INFJ
684389 84390INFJ
7welcome and stuffINFJ
8game set matchINFJ
9prozac wellbrutin at least thirty minutes of m...INFJ
\n", 296 | "
" 297 | ], 298 | "text/plain": [ 299 | " post type\n", 300 | "0 enfp and intj moments sportscenter not top ... INFJ\n", 301 | "1 what has been the most lifechanging experience... INFJ\n", 302 | "2 on repeat for most of today INFJ\n", 303 | "3 may the perc experience immerse you INFJ\n", 304 | "4 the last thing my infj friend posted on his fa... INFJ\n", 305 | "5 hello enfj7 sorry to hear of your distress its... INFJ\n", 306 | "6 84389 84390 INFJ\n", 307 | "7 welcome and stuff INFJ\n", 308 | "8 game set match INFJ\n", 309 | "9 prozac wellbrutin at least thirty minutes of m... INFJ" 310 | ] 311 | }, 312 | "execution_count": 5, 313 | "metadata": {}, 314 | "output_type": "execute_result" 315 | } 316 | ], 317 | "source": [ 318 | "print('{} rows'.format(df.shape[0]))\n", 319 | "\n", 320 | "# Preview the data\n", 321 | "df.head(10)" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 6, 327 | "metadata": {}, 328 | "outputs": [ 329 | { 330 | "data": { 331 | "text/html": [ 332 | "
\n", 333 | "\n", 346 | "\n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | "
posttype
type
INFP859361
INFJ692991
INTP608451
INTJ505181
ENTP327311
ENFP317941
ISTP158091
ISFP122891
ENTJ109071
ISTJ95591
ENFJ90471
ISFJ78071
ESTP42061
ESFP21331
ESFJ19851
ESTJ18701
\n", 442 | "
" 443 | ], 444 | "text/plain": [ 445 | " post type\n", 446 | "type \n", 447 | "INFP 85936 1\n", 448 | "INFJ 69299 1\n", 449 | "INTP 60845 1\n", 450 | "INTJ 50518 1\n", 451 | "ENTP 32731 1\n", 452 | "ENFP 31794 1\n", 453 | "ISTP 15809 1\n", 454 | "ISFP 12289 1\n", 455 | "ENTJ 10907 1\n", 456 | "ISTJ 9559 1\n", 457 | "ENFJ 9047 1\n", 458 | "ISFJ 7807 1\n", 459 | "ESTP 4206 1\n", 460 | "ESFP 2133 1\n", 461 | "ESFJ 1985 1\n", 462 | "ESTJ 1870 1" 463 | ] 464 | }, 465 | "execution_count": 6, 466 | "metadata": {}, 467 | "output_type": "execute_result" 468 | } 469 | ], 470 | "source": [ 471 | "unique.sort_values(by=['post'], ascending=False)" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "# Save preprocessed data to csv" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 78, 484 | "metadata": { 485 | "collapsed": true 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "df.to_csv(outfilename)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": null, 495 | "metadata": { 496 | "collapsed": true 497 | }, 498 | "outputs": [], 499 | "source": [] 500 | } 501 | ], 502 | "metadata": { 503 | "kernelspec": { 504 | "display_name": "Python 3", 505 | "language": "python", 506 | "name": "python3" 507 | }, 508 | "language_info": { 509 | "codemirror_mode": { 510 | "name": "ipython", 511 | "version": 3 512 | }, 513 | "file_extension": ".py", 514 | "mimetype": "text/x-python", 515 | "name": "python", 516 | "nbconvert_exporter": "python", 517 | "pygments_lexer": "ipython3", 518 | "version": "3.6.3" 519 | } 520 | }, 521 | "nbformat": 4, 522 | "nbformat_minor": 2 523 | } 524 | --------------------------------------------------------------------------------