├── README.md ├── data └── example.gif ├── main.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | # Interactive Web App with Streamlit and Scikit-learn 2 | Explore different datasets and classifier. This tutorial should demonstrate how easy interactive web applications can be build with *streamlit*. Streamlit lets you create apps for your machine learning projects with simple Python scripts. See official [streamlit website](https://www.streamlit.io/) for more info. 3 | 4 | ## Preview 5 | ![Example of Streamlit|635x380](data/example.gif) 6 | 7 | ## Watch the Tutorial 8 | [![Alt text](https://img.youtube.com/vi/Klqn--Mu2pE/hqdefault.jpg)](https://www.youtube.com/watch?v=Klqn--Mu2pE) 9 | 10 | ## Installation 11 | You need these dependencies: 12 | ```console 13 | pip install streamlit 14 | pip install scikit-learn 15 | pip install matplotlib 16 | ``` 17 | 18 | ## Usage 19 | Run 20 | ```console 21 | streamlit run main.py 22 | ``` 23 | -------------------------------------------------------------------------------- /data/example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickloeber/streamlit-demo/dd512ff9b8d6c407c02652053085749b39b70c66/data/example.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import numpy as np 3 | 4 | import matplotlib.pyplot as plt 5 | from sklearn import datasets 6 | from sklearn.model_selection import train_test_split 7 | 8 | from sklearn.decomposition import PCA 9 | from sklearn.svm import SVC 10 | from sklearn.neighbors import KNeighborsClassifier 11 | from sklearn.ensemble import RandomForestClassifier 12 | 13 | from sklearn.metrics import accuracy_score 14 | 15 | st.title('Streamlit Example') 16 | 17 | st.write(""" 18 | # Explore different classifier and datasets 19 | Which one is the best? 20 | """) 21 | 22 | dataset_name = st.sidebar.selectbox( 23 | 'Select Dataset', 24 | ('Iris', 'Breast Cancer', 'Wine') 25 | ) 26 | 27 | st.write(f"## {dataset_name} Dataset") 28 | 29 | classifier_name = st.sidebar.selectbox( 30 | 'Select classifier', 31 | ('KNN', 'SVM', 'Random Forest') 32 | ) 33 | 34 | def get_dataset(name): 35 | data = None 36 | if name == 'Iris': 37 | data = datasets.load_iris() 38 | elif name == 'Wine': 39 | data = datasets.load_wine() 40 | else: 41 | data = datasets.load_breast_cancer() 42 | X = data.data 43 | y = data.target 44 | return X, y 45 | 46 | X, y = get_dataset(dataset_name) 47 | st.write('Shape of dataset:', X.shape) 48 | st.write('number of classes:', len(np.unique(y))) 49 | 50 | def add_parameter_ui(clf_name): 51 | params = dict() 52 | if clf_name == 'SVM': 53 | C = st.sidebar.slider('C', 0.01, 10.0) 54 | params['C'] = C 55 | elif clf_name == 'KNN': 56 | K = st.sidebar.slider('K', 1, 15) 57 | params['K'] = K 58 | else: 59 | max_depth = st.sidebar.slider('max_depth', 2, 15) 60 | params['max_depth'] = max_depth 61 | n_estimators = st.sidebar.slider('n_estimators', 1, 100) 62 | params['n_estimators'] = n_estimators 63 | return params 64 | 65 | params = add_parameter_ui(classifier_name) 66 | 67 | def get_classifier(clf_name, params): 68 | clf = None 69 | if clf_name == 'SVM': 70 | clf = SVC(C=params['C']) 71 | elif clf_name == 'KNN': 72 | clf = KNeighborsClassifier(n_neighbors=params['K']) 73 | else: 74 | clf = clf = RandomForestClassifier(n_estimators=params['n_estimators'], 75 | max_depth=params['max_depth'], random_state=1234) 76 | return clf 77 | 78 | clf = get_classifier(classifier_name, params) 79 | #### CLASSIFICATION #### 80 | 81 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234) 82 | 83 | clf.fit(X_train, y_train) 84 | y_pred = clf.predict(X_test) 85 | 86 | acc = accuracy_score(y_test, y_pred) 87 | 88 | st.write(f'Classifier = {classifier_name}') 89 | st.write(f'Accuracy =', acc) 90 | 91 | #### PLOT DATASET #### 92 | # Project the data onto the 2 primary principal components 93 | pca = PCA(2) 94 | X_projected = pca.fit_transform(X) 95 | 96 | x1 = X_projected[:, 0] 97 | x2 = X_projected[:, 1] 98 | 99 | fig = plt.figure() 100 | plt.scatter(x1, x2, 101 | c=y, alpha=0.8, 102 | cmap='viridis') 103 | 104 | plt.xlabel('Principal Component 1') 105 | plt.ylabel('Principal Component 2') 106 | plt.colorbar() 107 | 108 | #plt.show() 109 | st.pyplot(fig) 110 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | matplotlib 3 | scikit-learn 4 | --------------------------------------------------------------------------------