├── .gitignore
├── README.md
├── image
    ├── 1
    ├── CodeCogsEqn (1).gif
    ├── adagrad1.png
    ├── adaptivelr.png
    ├── derivative.png
    ├── derivative1.png
    ├── derivative2.png
    ├── imbd.png
    ├── lr_large.png
    ├── lr_small.png
    ├── mnist.png
    ├── pseudocode1.png
    └── pseudocode2.png
└── mashroom
    ├── cs677-project.pptx
    ├── cs677project.ipynb
    └── read.me


/.gitignore:
--------------------------------------------------------------------------------
Mercury/
Mercury.modules
*.mh
*.err
*.init
*.dll
*.exe
*.a
*.so
*.dylib
*.beams
*.d
*.c_date

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Machine Learning

***Explained machine learning models and projects based on ML***

Basic machine learning algorithms:

1. **Linear regression**

2. **Logistic regression**

3. **Support Vector Machines**

4. **Naive Bayes**

*(below coming soon...)*

5. Decision trees

6. Random forest

7. k-means clustering

8. k-nearest neighbors



## Linear regression

Linear regression is very simple and basic. First, linear regression is a supervised model, which means the data must be labelled. Linear regression finds the relationship between the features (x1, x2, x3, ...) and the target, expressed as the coefficients of these variables.


### simple linear regression


Let's look at a simple linear regression equation:

![equation](https://latex.codecogs.com/gif.latex?y%20%3D%20%5CTheta%20_1x+%5CTheta%20_0)

▷ θ1 is the coefficient of the independent variable (slope)

▷ θ0 is the constant term or the y intercept.
Then consider this dataset as tuples of (1, 18), (2, 22), (3, 45), (4, 49), (5, 86)

• We want to fit a straight line to the given data

• Assume the line has the equation Y = θ1X + θ0

• Our goal is to minimize the errors between the line and the data

To minimize these distances (errors), we need to find proper values of θ1 and θ0. We build a function that measures the errors, which is often referred to as a ***loss function***.

Three common loss functions are:

(1) MSE (Mean Squared Error)

(2) RMSE (Root Mean Squared Error)

(3) Log loss (Cross Entropy loss)


In this case, we choose **Mean Squared Error** (the extra factor 1/2 only makes the derivative cleaner):

![equation](https://latex.codecogs.com/gif.latex?l%28%5CTheta%20_0%2C%5CTheta%20_1%29%20%3D%20%5Cfrac%7B1%7D%7B2n%7D%5Csum_%7Bi%7D%20%28f%28x_%7Bi%7D%29-y_%7Bi%7D%29%5E%7B2%7D)

So when we want to fit a line to the given data, we need to minimize the loss function.

Then, expanding the loss function:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20l%28%5CTheta%20_0%2C%5CTheta%20_1%29%20%26%20%3D%20%5Cfrac%7B1%7D%7B2n%7D%20%5Csum_%7Bi%7D%20%28f%28x_%7Bi%7D%29-y_%7Bi%7D%29%5E%7B2%7D%20%5Cnonumber%20%5C%5C%20%26%3D%20%5Cfrac%7B1%7D%7B2n%7D%20%5Csum_%7Bi%7D%28%5CTheta%20_1%20x_i%20+%20%5CTheta%20_0%20-%20y_%7Bi%7D%29%5E%7B2%7D%5Cnonumber%20%5Cend%7Balign%7D)

To keep the arithmetic short, fix θ0 = 0 (a line through the origin) and take the derivative with respect to θ1:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5CTheta%20_1%7D%20%26%20%3D%20%5Cfrac%7B1%7D%7Bn%7D%20%5Csum_%7Bi%7D%20%28%5CTheta%20_1%20x_%7Bi%7D%5E%7B2%7D%20-%20x_%7Bi%7Dy_%7Bi%7D%29%20%5Cnonumber%20%5C%5C%20%26%3D%20%5Cfrac%7B1%7D%7B5%7D%28%5CTheta%20_1%281%20+%204%20+%209%20+%2016%20+%2025%29%20-%20%2818%20+%2044%20+%20135%20+%20196%20+%20430%29%29%5Cnonumber%20%5C%5C%20%26%3D%20%5Cfrac%7B1%7D%7B5%7D%2855%5CTheta%20_1%20-%20823%29%5Cnonumber%20%5Cend%7Balign%7D)

Since the loss function is convex, it reaches its minimum where its derivative is 0: 55θ1 - 823 = 0, so the loss is minimized at θ1 = 823/55 ≈ 14.96.

Now consider multiple linear regression, where each entity x has d dimensions:

![equation](https://latex.codecogs.com/gif.latex?y%20%3D%20%5CTheta_%7B0%7D%20+%20%5CTheta_%7B1%7Dx_1%20+%20...%20+%20%5CTheta_%7Bd%7Dx_d)

Similarly, we get the loss function:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20l%28%5CTheta%20_0%2C%5CTheta%20_1...%5CTheta%20_d%29%20%26%20%3D%20%5Cfrac%7B1%7D%7B2n%7D%20%5Csum_%7Bi%7D%20%28f%28x_%7Bi%7D%29%20-%20y_%7Bi%7D%29%5E2%20%5Cnonumber%20%5Cend%7Balign%7D)

So in order to minimize the loss function, we need to choose each θi to minimize l(θ0, θ1, ..., θd). We do this with ***Gradient Descent***.

Gradient Descent is an iterative algorithm: it starts from an initial guess and incrementally improves the current solution. At iteration step iter, θ(iter) is the current guess for θ.


#### How to calculate gradient

Suppose ∇l(θ) is a vector whose ith entry is the partial derivative of l with respect to θi, evaluated at the current θ:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20%5Cnabla%20l%28%5Ctheta%29%20%3D%20%5Cbegin%7Bbmatrix%7D%5Cnonumber%20%5Cfrac%7B%5Cpartial%20l%28%5Ctheta%29%7D%7B%5Cpartial%20%5Ctheta%20_0%7D%5C%5C%20%5Cfrac%7B%5Cpartial%20l%28%5Ctheta%29%7D%7B%5Cpartial%20%5Ctheta%20_1%7D%5C%5C%20.%5C%5C%20.%5C%5C%20%5Cfrac%7B%5Cpartial%20l%28%5Ctheta%29%7D%7B%5Cpartial%20%5Ctheta%20_d%7D%5C%5C%20%5Cend%7Bbmatrix%7D%20%5Cend%7Balign%7D)


In previous sections, we got the loss function, which is


![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20l%28%5CTheta%20_0%2C%5CTheta%20_1...%5CTheta%20_d%29%20%26%20%3D%20%5Cfrac%7B1%7D%7B2n%7D%20%5Csum_%7Bi%7D%20%28f%28x_%7Bi%7D%29%20-%20y_%7Bi%7D%29%5E2%20%5Cnonumber%20%5Cend%7Balign%7D)

Then expand it:

![equation](https://latex.codecogs.com/gif.latex?l%28%5Ctheta%20_0%2C%5Ctheta%20_1%2C...%2C%5Ctheta%20_d%29%3D%20%5Cfrac%7B1%7D%7B2n%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%28y%5E%7B%28i%29%7D-%20%28%5Ctheta%20_0+%20%5Ctheta%20_1x_1%5E%7B%28i%29%7D+...+%5Ctheta%20_dx_d%5E%7B%28i%29%7D%29%29%5E%7B2%7D)

Since it has multiple dimensions, we compute the partial derivative with respect to each θj:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20%5Cfrac%7B%5Cpartial%20l%28%5Ctheta%29%7D%7B%5Cpartial%20%5Ctheta%20_0%7D%20%3D%20-%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%28y%5E%7B%28i%29%7D%20-%20%28%5Ctheta%20_0+%5Ctheta%20_1x_1%5E%7B%28i%29%7D%20+%20...+%20%5Ctheta%20_dx_d%5E%7B%28i%29%7D%29%29%5Cnonumber%5C%5C%20%5Cfrac%7B%5Cpartial%20l%28%5Ctheta%29%7D%7B%5Cpartial%20%5Ctheta%20_1%7D%20%3D%20-%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%20x_%7B1%7D%5E%7B%28i%29%7D%28y%5E%7B%28i%29%7D%20-%20%28%5Ctheta%20_0+%5Ctheta%20_1x_1%5E%7B%28i%29%7D%20+%20...+%20%5Ctheta%20_dx_d%5E%7B%28i%29%7D%29%29%5Cnonumber%5C%5C%20%5Cfrac%7B%5Cpartial%20l%28%5Ctheta%29%7D%7B%5Cpartial%20%5Ctheta%20_2%7D%20%3D%20-%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%20x_%7B2%7D%5E%7B%28i%29%7D%28y%5E%7B%28i%29%7D%20-%20%28%5Ctheta%20_0+%5Ctheta%20_1x_1%5E%7B%28i%29%7D%20+%20...+%20%5Ctheta%20_dx_d%5E%7B%28i%29%7D%29%29%5Cnonumber%5C%5C%20...%5Cnonumber%5C%5C%20%5Cfrac%7B%5Cpartial%20l%28%5Ctheta%29%7D%7B%5Cpartial%20%5Ctheta%20_d%7D%20%3D%20-%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%20x_%7Bd%7D%5E%7B%28i%29%7D%28y%5E%7B%28i%29%7D%20-%20%28%5Ctheta%20_0+%5Ctheta%20_1x_1%5E%7B%28i%29%7D%20+%20...+%20%5Ctheta%20_dx_d%5E%7B%28i%29%7D%29%29%5Cnonumber%20%5Cend%7Balign%7D)

Now we can compute each component of the gradient by summing over the data, and then update the weights in the next iteration.
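All of these partial derivatives can be computed at once with one matrix product. The sketch below assumes NumPy and a design matrix `X` whose rows are the entities (add a column of ones if θ0 is part of θ); `mse_loss` and `mse_gradient` are illustrative names, not code from this repository. It reuses the toy dataset from the simple regression example with θ0 fixed at 0, checks the analytic gradient against a central finite difference, and confirms that the gradient vanishes at θ1 = 823/55 ≈ 14.96.

```python
import numpy as np


def mse_loss(theta, X, y):
    """l(θ) = 1/(2n) Σ_i (x^(i)·θ - y^(i))^2"""
    residual = X @ theta - y
    return residual @ residual / (2 * len(y))


def mse_gradient(theta, X, y):
    """Vector of partial derivatives: ∂l/∂θ_j = -1/n Σ_i x_j^(i) (y^(i) - x^(i)·θ)."""
    return -X.T @ (y - X @ theta) / len(y)


# Toy data from the simple-regression example: one feature, θ0 fixed at 0
X = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]])
y = np.array([18.0, 22.0, 45.0, 49.0, 86.0])

theta = np.array([10.0])
grad = mse_gradient(theta, X, y)  # (55 * 10 - 823) / 5 = -54.6

# Central finite difference as a sanity check on the analytic formula
eps = 1e-6
numeric = (mse_loss(theta + eps, X, y) - mse_loss(theta - eps, X, y)) / (2 * eps)

# Root of (55 θ1 - 823) / 5 = 0; the gradient should vanish here
theta_star = np.array([823.0 / 55.0])
grad_at_star = mse_gradient(theta_star, X, y)
print(theta_star, grad, numeric, grad_at_star)
```

The same `mse_gradient` works unchanged for d features, since `X.T @ (...)` sums x_j^(i) times the residual over all entities for every j.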

#### Gradient Descent pseudocode

![pseudocode](https://github.com/gnayoaixgnaw/machine_learning_project/blob/main/image/pseudocode1.png)

• Here λ is the "learning rate" and controls the speed of convergence

• ∇l(θ(iter)) is the gradient of l evaluated at iteration "iter" with parameter θ(iter)

• Stop conditions can vary

**When to stop**

The stop condition can vary, for example:

• The maximum number of iterations is reached (keep iterating only while iter < MaxIteration)

• The gradient ∇l(θ(iter)) or the parameters stop changing (||θ(iter+1) - θ(iter)|| < precisionValue)

• The cost stops decreasing (|l(θ(iter+1)) - l(θ(iter))| < precisionValue)

• A combination of the above

A more detailed version of the loop, written as runnable Python for the MSE loss above:

```python
import numpy as np


def gradient_descent(X, y, learning_rate=0.01, max_num_iteration=1000):
    """Batch gradient descent for the MSE loss l(θ) = 1/(2n) Σ (x^(i)·θ - y^(i))^2."""
    n, d = X.shape
    # initialize parameters
    iteration = 0
    theta = np.random.normal(0, 0.1, d)

    while iteration < max_num_iteration:
        # calculate gradients (all partial derivatives at once)
        gradients = -X.T @ (y - X @ theta) / n
        # update parameters
        theta -= learning_rate * gradients
        iteration += 1
    return theta
```


#### Implement code via PySpark

***Check [here](https://github.com/gnayoaixgnaw/Big_Data_Analytics/tree/main/assignment3)***



## Regularization in the loss function

We will always face the **over-fitting issue** in real problems.
The **over-fitting issue** means that the model's parameters become large and the model is not robust: a small change in the test data can cause a large difference in the result. To avoid over-fitting, we add a regularization term to the loss function.

### l1 norm

One option is to drive parameters that contribute little to exactly 0, which produces a sparse parameter vector. This is done with the l1 norm (the sum of the absolute values of the parameters):

![equation](https://latex.codecogs.com/gif.latex?l_1%20%3D%20l+%5Clambda%20%5Csum_%7Bi%3D1%7D%5E%7Bd%7D%5Cleft%20%7C%20%5Ctheta%20_i%20%5Cright%20%7C)

where l is the loss function, ∑i |θi| is the l1 regularizer, λ is the regularization coefficient, and θi are the parameters.
We can visualize the l1-regularized loss function:

![l1](https://i.loli.net/2018/11/28/5bfe89e366bba.jpg)

The contour lines in the figure are those of l, and the black square is the graph of the L1 term. The point where a contour line of l first touches the square is the optimal solution. The contour lines tend to touch the square first at one of its vertices, because l is much more likely to meet those corners than any other part of the square. At a vertex some coordinates are 0, so the corresponding features get zero weight; the resulting sparse parameter vector can be used for feature selection.

### l2 norm

We can keep the parameters as small as possible by adding the l2 norm:

![equation](https://latex.codecogs.com/gif.latex?l_2%20%3D%20l+%5Clambda%20%5Csum_%7Bi%3D1%7D%5E%7Bd%7D%5Cleft%20%7C%20%5Ctheta%20_i%20%5Cright%20%7C%5E%7B2%7D)


where l is the loss function, ∑i |θi|² is the l2 regularizer, λ is the regularization coefficient, and θi are the parameters.
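Both penalties are easy to bolt onto an existing loss and gradient. A minimal NumPy sketch, assuming the MSE loss from the linear regression section; the function name `regularized_loss_and_grad` is hypothetical. Because |θi| is not differentiable at 0, the l1 branch uses the subgradient sign(θi) (which picks 0 at θi = 0), and every entry of `theta` is penalized, including an intercept if you keep one in it.

```python
import numpy as np


def regularized_loss_and_grad(theta, X, y, lam, norm="l2"):
    """MSE loss l(θ) plus an l1 or l2 penalty, with the matching (sub)gradient."""
    n = len(y)
    residual = X @ theta - y
    loss = residual @ residual / (2 * n)   # l(θ) = 1/(2n) Σ (x^(i)·θ - y^(i))^2
    grad = X.T @ residual / n
    if norm == "l1":
        loss += lam * np.sum(np.abs(theta))   # λ Σ |θ_i|
        grad += lam * np.sign(theta)          # subgradient, 0 chosen at θ_i = 0
    elif norm == "l2":
        loss += lam * np.sum(theta ** 2)      # λ Σ θ_i^2
        grad += 2 * lam * theta
    else:
        raise ValueError("norm must be 'l1' or 'l2'")
    return loss, grad


X = np.eye(2)
y = np.array([1.0, 1.0])
theta = np.array([1.0, -2.0])
print(regularized_loss_and_grad(theta, X, y, lam=0.5, norm="l1"))
print(regularized_loss_and_grad(theta, X, y, lam=0.5, norm="l2"))
```

The l1 term adds a constant-size push λ·sign(θi) toward 0, which is what lets small weights land exactly on 0, while the l2 term's push 2λθi shrinks in proportion to the weight itself.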
We can visualize the l2-regularized loss function:

![l2](https://i.loli.net/2018/11/28/5bfe89e366bba.jpg)

Compared with the update rule without L2 regularization, the penalty adds 2λθi to each partial derivative, so a gradient step with learning rate α becomes θi ← (1 - 2αλ)θi - α ∂l/∂θi. As long as 2αλ < 1, the parameters are multiplied by a factor smaller than 1 in every iteration, which keeps shrinking them toward 0.




## Logistic regression

Logistic regression is a supervised model for classification problems. There are binary-class and multi-class logistic regression.


### Binary-class logistic regression


Suppose we have a binary prediction problem. It is natural to assume that the output y (0/1), given the independent variables X (with d dimensions) and the model parameters θ, is sampled from a distribution in the exponential family.

Specifically, a 0/1 outcome follows a Bernoulli distribution with success probability p. For n Bernoulli samples x_1, ..., x_n, the likelihood and the log-likelihood are:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20L%28p%7Cx_1%2Cx_2...%2Cx_n%29%20%26%3D%20%5Cprod_%7Bi%20%3D%201%7D%5E%7Bn%7Dp%5E%7Bx_i%7D%281-p%29%5E%7B%281-x_i%29%7D%5Cnonumber%20%5C%5C%20%5Clog%20L%28p%7Cx_1%2Cx_2...%2Cx_n%29%20%26%3D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5Bx_i%5Clog%20%28p%29+%281-x_i%29%5Clog%20%281-p%29%5D%20%5Cnonumber%20%5Cend%7Balign%7D)

For example, take the following data points (x1, x2), where the output Y is 0 or 1:

(92, 12), (23, 67), (67, 92), (98, 78), (18, 45), (6, 100)

Final result in class: 0, 0, 1, 1, 0, 0

• If coefs are (-1, -1), LLH is -698

• If coefs are (1, -1), LLH is -133

• If coefs are (-1, 1), LLH is 7.4

• If coefs are (1, 1), LLH is 394


However, this is not enough to get the loss function: logistic regression needs a ***sigmoid*** function to give the probability that y = 1, which is:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20P%28y_i%3D1%7Cx_i%29%20%26%3D%20%5Cfrac%7B1%7D%7B1+e%5E%7B-%28%5Ctheta%20_0+%5Ctheta_1%20x_1+...+%5Ctheta_d%20x_d%29%7D%7D%5Cnonumber%20%5C%5C%20%26%3D%5Cfrac%7Be%5E%7B%5Ctheta%20_0+%5Ctheta_1%20x_1+...+%5Ctheta_d%20x_d%7D%7D%7B1+e%5E%7B%5Ctheta%20_0+%5Ctheta_1%20x_1+...+%5Ctheta_d%20x_d%7D%7D%20%5Cnonumber%20%5Cend%7Balign%7D)

The exponent ω_i is a linear function of the entity's features: treating x^(i) as a vector (with x_0^(i) = 1 for the intercept θ0), ω_i can be represented as:

![equation](https://latex.codecogs.com/gif.latex?%5Comega%20_i%20%3D%20%5Csum_%7Bj%20%3D%200%7D%5E%7Bd%7Dx_j%5E%7B%28i%29%7D%20%5Ctheta%20_j)

where θ is the vector of regression coefficients and j indexes the entity's jth dimension.

Now it's time to write the log-likelihood of logistic regression (averaged over the n samples):

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20L%28p%7Cx_1%2Cx_2...%2Cx_n%2C%20y_1%2Cy_2...%2Cy_n%29%20%26%3D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5By_i%5Clog%20%5Cfrac%7Be%5E%7B%5Comega%20_i%7D%7D%7B1+e%5E%7B%5Comega_i%7D%7D+%281-y_i%29%5Clog%20%281-%5Cfrac%7Be%5E%7B%5Comega_i%7D%7D%7B1+e%5E%7B%5Comega_i%7D%7D%29%5D%5Cnonumber%20%5C%5C%20%26%3D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5By_i%28%5Clog%20e%5E%7B%5Comega_i%7D%29-%20%5Clog%20%281+e%5E%7B%5Comega_i%7D%29%5D%5Cnonumber%20%5C%5C%20%26%3D%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5By_i%5Comega_i-%5Clog%20%281+e%5E%7B%5Comega_i%7D%29%5D%5Cnonumber%20%5Cend%7Balign%7D)


Now calculate the loss function. Since gradient descent minimizes the loss function, the loss function should be the negative LLH:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20loss%20function%20%26%3D%20-%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5By_i%5Comega_i-%5Clog%20%281+e%5E%7B%5Comega_i%7D%29%5D%5Cnonumber%20%5C%5C%20%26%3D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5B-%20y_i%5Comega_i%20+%20%5Clog%20%281+e%5E%7B%5Comega_i%7D%29%5D%5Cnonumber%20%5Cend%7Balign%7D)

Applying regularization (l2 norm):

![equation](https://latex.codecogs.com/gif.latex?loss%20function%20%3D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5B-%20y_i%5Comega_i%20+%20%5Clog%20%281+e%5E%7B%5Comega_i%7D%29%5D%20+%20%5Clambda%20%5Csum_%7Bj%3D1%7D%5E%7Bd%7D%5Ctheta%20_j%20%5E%7B2%7D)

where j runs over the entity's d feature dimensions.


#### How to calculate gradient

The gradient is the vector of partial derivatives with respect to each θj:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20%5Cnabla%20l%28%5Ctheta%29%20%3D%20%5Cbegin%7Bbmatrix%7D%5Cnonumber%20%5Cfrac%7B%5Cpartial%20l%28%5Ctheta%29%7D%7B%5Cpartial%20%5Ctheta%20_0%7D%5C%5C%20%5Cfrac%7B%5Cpartial%20l%28%5Ctheta%29%7D%7B%5Cpartial%20%5Ctheta%20_1%7D%5C%5C%20.%5C%5C%20.%5C%5C%20%5Cfrac%7B%5Cpartial%20l%28%5Ctheta%29%7D%7B%5Cpartial%20%5Ctheta%20_d%7D%5C%5C%20%5Cend%7Bbmatrix%7D%20%5Cend%7Balign%7D)

Since it has multiple dimensions, we compute the partial derivatives:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_1%7D%20%3D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5B-y_ix_1%5E%7B%28i%29%7D+x_1%5E%7B%28i%29%7D%5Cfrac%7Be%5E%7B%5Comega_i%7D%7D%7B1+e%5E%7B%5Comega_i%7D%7D%5D+2%5Clambda%20%5Ctheta%20_1%20%5Cnonumber%20%5C%5C%20%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_2%7D%20%3D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5B-y_ix_2%5E%7B%28i%29%7D+x_2%5E%7B%28i%29%7D%5Cfrac%7Be%5E%7B%5Comega_i%7D%7D%7B1+e%5E%7B%5Comega_i%7D%7D%5D+2%5Clambda%20%5Ctheta%20_2%20%5Cnonumber%20%5C%5C%20...%20%5Cnonumber%20%5C%5C%20%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_d%7D%20%3D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5B-y_ix_d%5E%7B%28i%29%7D+x_d%5E%7B%28i%29%7D%5Cfrac%7Be%5E%7B%5Comega_i%7D%7D%7B1+e%5E%7B%5Comega_i%7D%7D%5D+2%5Clambda%20%5Ctheta%20_d%20%5Cnonumber%20%5Cend%7Balign%7D)


#### Gradient Descent pseudocode in PySpark

![pseudocode1](https://github.com/gnayoaixgnaw/machine_learning_project/blob/main/image/pseudocode2.png)


#### Implement code via PySpark

***Check [here](https://github.com/gnayoaixgnaw/Big_Data_Analytics/tree/main/assignment4)***


### multi-class logistic regression

In the binary-class LR model, we use the sigmoid function to map each sample to a probability in (0, 1). Many problems need multi-class classification instead, so we use the ***softmax*** function to map each sample to a probability in (0, 1) for every class, with the probabilities summing to 1.
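A minimal NumPy sketch of softmax, with illustrative function names. Subtracting the largest score before exponentiating avoids overflow and does not change the result, because the common factor e^(-max) cancels in the ratio. It also shows how the binary case fits in: with two classes whose scores are 0 and z, softmax gives e^z / (1 + e^z), which is exactly the sigmoid used above.

```python
import numpy as np


def softmax(scores):
    """Map a vector (or the rows of a matrix) of class scores to probabilities."""
    # Subtracting the max leaves the ratios unchanged but prevents overflow in exp
    shifted = scores - np.max(scores, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


z = 1.5
probs = softmax(np.array([0.0, z]))  # two classes with scores 0 and z
print(probs, sigmoid(z))             # probs[1] equals sigmoid(z)
print(softmax(np.array([1000.0, 1001.0, 999.0])))  # large scores do not overflow
```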

Softmax can be written as a hypothesis function:

![equation](https://latex.codecogs.com/gif.latex?h_%5Ctheta%20%28x%5E%7B%28i%29%7D%29%20%3D%20%5Cbegin%7Bbmatrix%7D%20p%28y%5E%7B%28i%29%7D%20%3D%201%7Cx%5E%7B%28i%29%7D%3B%5Ctheta%20%29%5C%5C%20p%28y%5E%7B%28i%29%7D%20%3D%202%7Cx%5E%7B%28i%29%7D%3B%5Ctheta%20%29%5C%5C%20...%5C%5C%20p%28y%5E%7B%28i%29%7D%20%3D%20k%7Cx%5E%7B%28i%29%7D%3B%5Ctheta%20%29%20%5Cend%7Bbmatrix%7D%20%3D%20%5Cfrac%7B1%7D%7B%5Csum_%7Bj%3D1%7D%5E%7Bk%7De%5E%7B%5Ctheta_j%20%5E%7BT%7D%20x%5E%7B%28i%29%7D%7D%7D%5Cbegin%7Bbmatrix%7D%20e%5E%7B%5Ctheta_1%20%5E%7BT%7D%20x%5E%7B%28i%29%7D%7D%5C%5C%20e%5E%7B%5Ctheta_2%20%5E%7BT%7D%20x%5E%7B%28i%29%7D%7D%5C%5C%20...%5C%5C%20e%5E%7B%5Ctheta_k%20%5E%7BT%7D%20x%5E%7B%28i%29%7D%7D%20%5Cend%7Bbmatrix%7D)

where k is the total number of classes and i indexes the ith entity.

Then we can get the loss function, which is also called the log-likelihood cost:

![equation](https://latex.codecogs.com/gif.latex?J%28%5Ctheta%20%29%20%3D%20-%5Cfrac%7B1%7D%7Bn%7D%5B%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5Csum_%7Bj%3D1%7D%5E%7Bk%7D1%5Cleft%20%5C%7B%20y%5E%7B%28i%29%7D%20%3D%20j%20%5Cright%20%5C%7D%5Clog%20%5Cfrac%7Be%5E%7B%5Ctheta%20_j%5ET%20x%5E%7B%28i%29%7D%7D%7D%7B%5Csum_%7Bl%3D1%7D%5E%7Bk%7De%5E%7B%5Ctheta%20_l%5ET%20x%5E%7B%28i%29%7D%7D%7D%5D)

where 1{expression} is an indicator function: 1{expression} = 1 if the expression in {} is true, else 0.
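The cost J(θ) above can be computed in vectorized form. The sketch below assumes NumPy, stores θ1..θk as the columns of a (d, k) matrix `Theta`, and uses labels 0..k-1 instead of 1..k; the optional `lam` adds the l2 term from the next step, and `grad` follows the gradient formula given in the "How to calculate gradient" subsection below. Function names are hypothetical.

```python
import numpy as np


def softmax_rows(scores):
    """Row-wise softmax with the max subtracted for numerical stability."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def softmax_cost_and_grad(Theta, X, y, lam=0.0):
    """J(θ) and dJ/dθ for softmax regression.

    Theta: (d, k), column j is θ_j; X: (n, d); y: integer labels in 0..k-1.
    """
    n = X.shape[0]
    rows = np.arange(n)
    P = softmax_rows(X @ Theta)         # P[i, j] = p(y^(i) = j | x^(i); θ)
    Y = np.zeros_like(P)
    Y[rows, y] = 1.0                    # one-hot form of 1{y^(i) = j}
    cost = -np.mean(np.log(P[rows, y])) + lam * np.sum(Theta ** 2)
    grad = -X.T @ (Y - P) / n + 2 * lam * Theta   # column j is dJ/dθ_j
    return cost, grad


rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))
y = np.array([0, 1, 2, 0, 1, 2])
cost, grad = softmax_cost_and_grad(np.zeros((4, 3)), X, y)
print(cost, np.log(3))  # with Θ = 0 every class has probability 1/3
```

Only the log-probability of the true class survives the indicator sum, which is why the cost reduces to picking `P[rows, y]`.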

Then rearrange it and add the l2 norm:

![equation](https://latex.codecogs.com/gif.latex?J%28%5Ctheta%20%29%20%3D%20-%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5Csum_%7Bj%3D1%7D%5E%7Bk%7D1%5Cleft%20%5C%7B%20y%5E%7B%28i%29%7D%20%3D%20j%20%5Cright%20%5C%7D%5B%5Clog%20e%5E%7B%5Ctheta%20_j%5ET%20x%5E%7B%28i%29%7D%7D-%20%5Clog%20%5Csum_%7Bl%3D1%7D%5E%7Bk%7De%5E%7B%5Ctheta%20_l%5ET%20x%5E%7B%28i%29%7D%7D%5D%20+%20%5Clambda%20%5Csum_%7Bj%3D1%7D%5E%7Bk%7D%5Cleft%20%5C%7C%20%5Ctheta%20_j%20%5Cright%20%5C%7C%5E%7B2%7D)

#### How to calculate gradient

Taking the partial derivative with respect to θj (the parameter vector of class j):

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20%5Cfrac%7B%5Cpartial%20J%28%5Ctheta%20%29%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%20%26%20%3D%20-%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%20%5Bx%5E%7B%28i%29%7D%20%281%5Cleft%20%5C%7B%20y%5E%7B%28i%29%7D%20%3D%20j%20%5Cright%20%5C%7D%20-%20%5Cfrac%7Be%5E%7B%5Ctheta%20_j%5ET%20x%5E%7B%28i%29%7D%7D%7D%7B%5Csum_%7Bl%3D1%7D%5E%7Bk%7De%5E%7B%5Ctheta%20_l%5ET%20x%5E%7B%28i%29%7D%7D%7D%29%5D%20+%202%5Clambda%20%5Ctheta%20_j%5Cnonumber%5C%5C%20%26%3D%20-%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5B%20x%5E%7B%28i%29%7D%20%281%5Cleft%20%5C%7B%20y%5E%7B%28i%29%7D%20%3D%20j%20%5Cright%20%5C%7D%20-%20p%28y%5E%7B%28i%29%7D%3Dj%7Cx%5E%7B%28i%29%7D%3B%5Ctheta%20%29%29%5D%20+%202%5Clambda%20%5Ctheta%20_j%20%5Cnonumber%20%5Cend%7Balign%7D)


#### Gradient Descent pseudocode in PySpark

![pseudocode1](https://github.com/gnayoaixgnaw/machine_learning_project/blob/main/image/pseudocode2.png)


#### Implement code via PySpark

***Check [here]()***



## Support vector machine

The traditional SVM is a binary-class SVM; there is also a multi-class SVM.


### binary class svm

Suppose the dataset is linearly separable. Then it is possible to put a strip between the two classes, and the points that keep the strip from expanding are the 'support vectors'.

Every point x_0 on a line, plane or hyperplane satisfies the equation below, where w is the normal vector and b is the offset:

![equation](https://latex.codecogs.com/gif.latex?%5Cvec%7Bw%7D%5Ccdot%20x_0%20+%20b%20%3D%200)

Now take a point x. Its signed distance to this plane, scaled by ||w|| and written as y, is:

![equation](https://latex.codecogs.com/gif.latex?y%20%3D%20%5Cvec%7Bw%7D%5Ccdot%28x-x_0%29%20%3D%20%5Cvec%7Bw%7D%5Ccdot%20x%20-%20%5Cvec%7Bw%7D%5Ccdot%20x_0%20%3D%20%5Cvec%7Bw%7D%5Ccdot%20x+b)

The sign of y tells which side of the plane the point x is on, and the two classes are labelled +1 and -1.

This is because the basic SVM chooses two planes:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20%5Cvec%7Bw%7D%5Ccdot%20x+b%20%3D%201%20%5Cnonumber%20%5C%5C%20%5Cvec%7Bw%7D%5Ccdot%20x+b%20%3D%20-1%5Cnonumber%20%5Cend%7Balign%7D)

where y ≥ 1 means x is '+' (a positive sample)

where y ≤ -1 means x is '-' (a negative sample)

![image](https://pic4.zhimg.com/v2-197913c461c1953c30b804b4a7eddfcc_1440w.jpg?source=172ae18b)

The distance between the two planes is ![equation](https://latex.codecogs.com/gif.latex?%5Cfrac%7B2%7D%7B%5Cleft%20%5C%7C%20%5Cvec%7Bw%7D%20%5Cright%20%5C%7C%7D), so we need to maximize this distance to get the optimal solution.
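A quick numeric illustration of the two planes and the margin width, assuming NumPy and an arbitrary, hypothetical choice of w and b (not trained values). The third point lands strictly between the planes, which is the region the maximum-margin strip is meant to keep empty.

```python
import numpy as np

w = np.array([3.0, 4.0])   # hypothetical normal vector
b = -1.0                   # hypothetical offset

X = np.array([[1.0, 1.0],
              [0.0, 0.0],
              [0.5, 0.0]])
y = X @ w + b                     # w·x + b for each point (the "y" of the text)
side = np.where(y >= 0, 1, -1)    # which side of the plane w·x + b = 0
margin = 2.0 / np.linalg.norm(w)  # distance between w·x + b = 1 and w·x + b = -1
print(y, side, margin)
```

Here y = 6 puts the first point beyond the '+' plane, y = -1 puts the second exactly on the '-' plane, and y = 0.5 puts the third inside the strip of width 2/||w|| = 0.4.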

#### loss function

First, define a general loss function:

![equation](https://latex.codecogs.com/gif.latex?l%20%3D%20%5Csum_%7Bi%20%3D%201%7D%5E%7Bn%7Dl%28y_i%20%5E%7Bpred%7D%2C%20y_i%5E%7Btrue%7D%29)

When this loss is the hinge loss, it is exactly the SVM loss function (here b is folded into w by appending a constant 1 to each x_i):

![equation](https://latex.codecogs.com/gif.latex?l%20%3D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7Dmax%280%2C%201%20-%20y_i%5E%7Bpred%7D*y_i%5E%7Btrue%7D%29%20%3D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7Dmax%280%2C%201%20-%5Cvec%7Bw%7D%5Ccdot%20x_i*y_i%29)

Then add the l2 norm; the final loss function is:

![equation](https://latex.codecogs.com/gif.latex?l%20%3D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7Dmax%280%2C%201%20-%5Cvec%7Bw%7D%5Ccdot%20x_i*y_i%29%20+%20%5Clambda%20%5Cleft%20%5C%7C%20%5Cvec%7Bw%7D%20%5Cright%20%5C%7C%5E%7B2%7D)

#### How to calculate gradient

We can use the chain rule to compute the derivative:

![equation](https://latex.codecogs.com/gif.latex?%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Cvec%7Bw%7D%7D%20%3D%20%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%28%5Cvec%7Bw%7D%5Ccdot%20x%29%7D%5Cfrac%7B%5Cpartial%20%28%5Cvec%7Bw%7D%5Ccdot%20x%29%7D%7B%5Cpartial%20%28%5Cvec%7Bw%7D%29%7D%20+%202%5Clambda%20%5Cvec%7Bw%7D)

Then calculate the derivatives of the two parts for a single sample x_i:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20%5Cfrac%7B%5Cpartial%20max%280%2C%201-%20%5Cvec%7Bw%7D%5Ccdot%20x_i*y_i%20%5E%7Btrue%7D%29%7D%7B%5Cpartial%20%28%5Cvec%7Bw%7D%5Ccdot%20x_i%29%7D%20%5Cqquad%20%281%29%20%5Cnonumber%20%5C%5C%20%5Cfrac%7B%5Cpartial%20%28%5Cvec%7Bw%7D%5Ccdot%20x_i%29%7D%7B%5Cpartial%20%5Cvec%7Bw%7D%7D%20%3D%20x_i%20%5Cqquad%20%282%29%20%5Cnonumber%20%5Cend%7Balign%7D)

In conclusion, the gradient for a single sample x_i is:

(1) if ![equation](https://latex.codecogs.com/gif.latex?%281-%20%5Cvec%7Bw%7D%5Ccdot%20x_i*y_i%20%5E%7Btrue%7D%29%20%3C%200):

formula (1) = 0, which means the derivative = ![equation](https://latex.codecogs.com/gif.latex?0+%202%5Clambda%20%5Cvec%7Bw%7D)

(2) if ![equation](https://latex.codecogs.com/gif.latex?%281-%20%5Cvec%7Bw%7D%5Ccdot%20x_i*y_i%20%5E%7Btrue%7D%29%20%5Cgeq%200):

formula (1) = ![equation](https://latex.codecogs.com/gif.latex?-%20y_i%20%5E%7Btrue%7D), which means the derivative = ![equation](https://latex.codecogs.com/gif.latex?-%20y_i%20%5E%7Btrue%7Dx_i+2%5Clambda%20%5Cvec%7Bw%7D)

Then the final derivative for a ***batch of data*** can be written as:

![equation](https://latex.codecogs.com/gif.latex?%5Cfrac%7B1%7D%7Bn%7D%20%5Csum_%7Bi%20%3D%201%7D%5E%7Bn%7D%20%5Cleft%5C%7B%5Cbegin%7Bmatrix%7D%200%26%20%2Cif%20%281-%20y_i%20%5E%7Btrue%7D*%5Cvec%7Bw%7D%5Ccdot%20x_i%29%20%3C0%5C%5C%20-%20y_i%20%5E%7Btrue%7Dx_i%26%20%2Cif%20%281-%20y_i%20%5E%7Btrue%7D*%5Cvec%7Bw%7D%5Ccdot%20x_i%29%20%5Cgeq%200%20%5Cend%7Bmatrix%7D%5Cright.+%202%5Clambda%20%5Cvec%7Bw%7D)


#### Implement code via PySpark

***Check [here]()***



### multiple class svm

Unlike the binary-class SVM, the multi-class SVM loss requires the score of the correct class to be higher than the score of every incorrect class by at least ***Δ (a margin value)***.

Suppose the ***ith entity xi*** of the dataset has ***d*** features (represented as a vector) and class ***yi***. The SVM model f(xi, w) computes a score for each of the k classes, and we use **si** to denote this score vector, where s_j is the score of class j.
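A vectorized NumPy sketch of the multi-class SVM loss defined in the next subsection, together with its gradient as derived after it. The layout is an assumption: `W` is a (k, d) matrix whose row j is w_j, labels are 0..k-1, the batch loss is the mean of the per-sample hinge terms plus λ||W||², and the strict `> 0` test picks one valid subgradient at the kink. The function name is hypothetical. The usage lines reproduce the worked example with scores [12, -7, 10] and Δ = 9.

```python
import numpy as np


def multiclass_svm_loss_and_grad(W, X, y, delta=1.0, lam=0.0):
    """Mean multi-class hinge loss plus λ||W||^2 and its (sub)gradient w.r.t. W.

    W: (k, d), row j is w_j; X: (n, d); y: integer labels in 0..k-1.
    """
    n = X.shape[0]
    rows = np.arange(n)
    S = X @ W.T                                   # S[i, j] = s_j = w_j^T x_i
    margins = np.maximum(0.0, S - S[rows, y][:, None] + delta)
    margins[rows, y] = 0.0                        # the j = y_i term is excluded
    loss = margins.sum() / n + lam * np.sum(W ** 2)

    # Coefficient of x_i in each row of the gradient:
    # +1 for every violated j != y_i, minus the number of violations for j = y_i
    coef = (margins > 0).astype(float)
    coef[rows, y] = -coef.sum(axis=1)
    grad = coef.T @ X / n + 2 * lam * W
    return loss, grad


# Worked example: one sample whose scores are [12, -7, 10], correct class 0, Δ = 9
W = np.array([[12.0], [-7.0], [10.0]])
X = np.array([[1.0]])
loss, grad = multiclass_svm_loss_and_grad(W, X, np.array([0]), delta=9.0)
print(loss)   # 7.0
print(grad)   # rows: -x_i for the correct class, +x_i for the violated class 2
```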

#### loss function

According to the definition of the multi-class SVM, the loss of the ith entity xi, summed over the incorrect classes j, is:

![equation](https://latex.codecogs.com/gif.latex?l_i%20%3D%20%5Csum_%7Bj%5Cneq%20y_i%7Dmax%280%2Cs_j%20-%20s_%7By_i%7D+%5CDelta%20%29)

For example:

Suppose there are 3 classes, and for the ith sample xi we get scores = [12, -7, 10], where the first class (yi) is correct. With Δ = 9, applying the loss function above, the loss of xi is:

![equation](https://latex.codecogs.com/gif.latex?l_i%20%3D%20max%280%2C%20-7-12+9%29+max%280%2C%2010-12+9%29%20%3D%200%20+%207%20%3D%207)

Since each score is linear in the weights, s_j = w_j^T x_i, we expand the loss function:

![equation](https://latex.codecogs.com/gif.latex?l_i%20%3D%20%5Csum_%7Bj%5Cneq%20y_i%7Dmax%280%2C%5Cvec%7Bw%7D_j%5E%7BT%7Dx_i%20-%20%5Cvec%7Bw%7D_%7By_i%7D%5E%7BT%7Dx_i%20+%5CDelta%20%29)

Then add the l2 norm:

![equation](https://latex.codecogs.com/gif.latex?l_i%20%3D%20%5Csum_%7Bj%5Cneq%20y_i%7Dmax%280%2C%5Cvec%7Bw%7D_j%5E%7BT%7Dx_i%20-%20%5Cvec%7Bw%7D_%7By_i%7D%5E%7BT%7Dx_i%20+%5CDelta%29%20+%20%5Clambda%20%5Cleft%20%5C%7C%20%5Cvec%7Bw%7D%20%5Cright%20%5C%7C%5E2)

#### How to calculate gradient

Because w_yi and w_j (j ≠ yi) appear in the loss with opposite signs, we first calculate the derivative with respect to w_yi for the ith entity xi:

![equation](https://latex.codecogs.com/gif.latex?%5Cfrac%7B%5Cpartial%20l_i%7D%7B%5Cpartial%20%5Cvec%7Bw%7D_%7By_i%7D%7D%20%3D%20-%5Cleft%28%5Csum_%7Bj%5Cneq%20y_i%7D1%5Cleft%20%5C%7B%20%5Cvec%7Bw%7D_j%5E%7BT%7Dx_i%20-%20%5Cvec%7Bw%7D_%7By_i%7D%5E%7BT%7Dx_i%20+%5CDelta%20%3E%200%20%5Cright%20%5C%7D%5Cright%29x_i%20+%202%5Clambda%20%5Cvec%7Bw%7D_%7By_i%7D)

that is, x_i is subtracted once for every incorrect class whose margin is violated.

Then calculate the derivative with respect to w_j (j ≠ yi) for the ith entity xi:

![equation](https://latex.codecogs.com/gif.latex?%5Cfrac%7B%5Cpartial%20l_i%7D%7B%5Cpartial%20%5Cvec%7Bw%7D_j%7D%20%3D%20%5Cleft%5C%7B%5Cbegin%7Bmatrix%7D%20x_i+%202%5Clambda%20%5Cvec%7Bw%7D_j%26%2C%20if%20%28%5Cvec%7Bw%7D_j%5E%7BT%7Dx_i%20-%20%5Cvec%7Bw%7D_%7By_i%7D%5E%7BT%7Dx_i%20+%5CDelta%29%20%3E0%5C%5C%200+2%5Clambda%20%5Cvec%7Bw%7D_j%26%20%2Cif%20%28%5Cvec%7Bw%7D_j%5E%7BT%7Dx_i%20-%20%5Cvec%7Bw%7D_%7By_i%7D%5E%7BT%7Dx_i%20+%5CDelta%29%20%5Cleq%200%20%5Cend%7Bmatrix%7D%5Cright.)

So, leaving out the regularization term, the gradient for entity xi has one component per class w_j (j from 1 to k), and each component depends on the value of ![equation](https://latex.codecogs.com/gif.latex?%28%5Cvec%7Bw%7D_j%5E%7BT%7Dx_i%20-%20%5Cvec%7Bw%7D_%7By_i%7D%5E%7BT%7Dx_i%20+%5CDelta%29):

![equation](https://latex.codecogs.com/gif.latex?%5CDelta%20%5Cvec%7Bw%7D_j%5E%7B%28i%29%7D%20%3D%20%5Cleft%5C%7B%5Cbegin%7Bmatrix%7D%20x_i%20%26%20%5Ctext%7B%2C%20if%20%7Dj%5Cneq%20y_i%5Ctext%7B%20and%20%7D%28%5Cvec%7Bw%7D_j%5E%7BT%7Dx_i%20-%20%5Cvec%7Bw%7D_%7By_i%7D%5E%7BT%7Dx_i%20+%5CDelta%29%20%3E0%5C%5C%20-%5Cleft%28%5Csum_%7Bl%5Cneq%20y_i%7D1%5Cleft%20%5C%7B%20%5Cvec%7Bw%7D_l%5E%7BT%7Dx_i%20-%20%5Cvec%7Bw%7D_%7By_i%7D%5E%7BT%7Dx_i%20+%5CDelta%20%3E%200%20%5Cright%20%5C%7D%5Cright%29x_i%20%26%20%5Ctext%7B%2C%20if%20%7Dj%20%3D%20y_i%5C%5C%200%20%26%20%5Ctext%7B%2C%20otherwise%7D%20%5Cend%7Bmatrix%7D%5Cright.)

So the gradient for a ***batch of data*** can be written as:

![equation](https://latex.codecogs.com/gif.latex?%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Cvec%7Bw%7D_j%7D%20%3D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi%3D1%7D%5E%7Bn%7D%5CDelta%20%5Cvec%7Bw%7D_j%5E%7B%28i%29%7D%20+%202%5Clambda%20%5Cvec%7Bw%7D_j)

#### Implement code via PySpark

***Check [here]()***


## Naive Bayes

Naive Bayes is based on Bayes' theorem, which is:

![equation](https://latex.codecogs.com/gif.latex?P%28Y_k%7CX%29%20%3D%20%5Cfrac%7BP%28X%7CY_k%29P%28Y_k%29%7D%7B%5Csum_%7Bk%7D%5E%7B%7DP%28X%7CY%20%3D%20Y_k%29P%28Y_k%29%7D)

Suppose the dataset has m entities, each entity has n dimensions, and there are k classes. The data are:


![equation](https://latex.codecogs.com/gif.latex?%28x_1%5E%7B1%7D%2Cx_2%5E%7B1%7D%2C...x_n%5E%7B1%7D%2C%20y_1%29%2C%28x_1%5E%7B2%7D%2Cx_2%5E%7B2%7D%2C...x_n%5E%7B2%7D%2C%20y_2%29%2C...%2C%28x_1%5E%7Bm%7D%2Cx_2%5E%7Bm%7D%2C...x_n%5E%7Bm%7D%20%2C%20y_m%29)

We can get the joint probability P(X, Y) from the product rule:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20P%28X%2CY%3DC_k%29%20%26%3D%20P%28Y%20%3D%20C_k%29P%28X%3Dx%7CY%20%3D%20C_k%29%5Cnonumber%20%5C%5C%20%26%20%3D%20P%28Y%20%3D%20C_k%29P%28X_1%3Dx_1%2CX_2%3Dx_2%2C...X_n%3Dx_n%7CY%20%3D%20C_k%29%5Cnonumber%20%5Cend%7Balign%7D)

Naive Bayes assumes that the n dimensions of an entity are independent of each other given the class:

![equation](https://latex.codecogs.com/gif.latex?P%28X_1%3Dx_1%2CX_2%3Dx_2%2C...X_n%3Dx_n%7CY%20%3D%20C_k%29%5C%5C%20%3D%20P%28X_1%3Dx_1%7CY%20%3D%20C_k%29P%28X_2%3Dx_2%7CY%20%3D%20C_k%29...P%28X_n%3Dx_n%7CY%20%3D%20C_k%29)

Notice that some dimensions are discrete while others are continuous.

Suppose ***S_Ck*** is the subset of items whose class is Ck, and ***S_k,xi*** is the subset of S_Ck whose ith dimension equals xi.

***discrete type***

![equation](https://latex.codecogs.com/gif.latex?P%28x_i%7CY%20%3D%20C_k%29%20%3D%20%5Cfrac%7B%7CS_%7Bk%2Cx_i%7D%7C%7D%7B%7CS_k%7C%7D)

***continuous type***

Suppose the ith dimension within class Ck follows a Gaussian distribution:

![equation](https://latex.codecogs.com/gif.latex?x_i%7CY%20%3D%20C_k%20%5Csim%20N%28%5Cmu%20_%7BC_k%2Ci%7D%2C%5Csigma_%7BC_k%2Ci%7D%5E%7B2%7D%29)

then we can get:

![equation](https://latex.codecogs.com/gif.latex?P%28x_i%7CY%20%3D%20C_k%29%20%3D%20%5Cfrac%7B1%7D%7B%5Csqrt%7B2%5Cpi%20%5Csigma%20_%7BC_k%2Ci%7D%5E%7B2%7D%7D%7De%5E%7B-%5Cfrac%7B%28x_i-%5Cmu%20_%7BC_k%2Ci%7D%29%5E2%7D%7B2%5Csigma%20_%7BC_k%2Ci%7D%5E2%7D%7D)

After comparing P(X, Y) across all classes, we predict the class with the highest P(X, Y).


### Laplacian correction

For example, if no sample belongs to Ck with its ith dimension equal to Xa, then ![equation](https://latex.codecogs.com/gif.latex?P%28X%20%3D%20X_a%7CY%20%3D%20C_k%29%20%3D%200).

Because P(X, Y = Ck) is a product of these per-dimension terms, this single zero makes ![equation](https://latex.codecogs.com/gif.latex?P%28X%2CY%20%3D%20C_k%29%20%3D%200), regardless of the other dimensions.

Laplacian correction (add-one smoothing) avoids this by adding 1 to every count:

![equation](https://latex.codecogs.com/gif.latex?P%28Y%20%3D%20C_k%29%20%3D%20%5Cfrac%7B%7CS_%7BC_k%7D%7C+1%7D%7B%7CS%7C+N%7D)

where |S_Ck| is the number of samples of class Ck, |S| is the size of the whole dataset, and N is the number of classes.

![equation](https://latex.codecogs.com/gif.latex?P%28x_i%7CY%20%3D%20C_k%29%20%3D%20%5Cfrac%7B%7CS_%7Bk%2Cx_i%7D%7C+1%7D%7B%7CS_%7Bk%7D%7C+N_i%7D)

where |S_k,xi| is the number of samples of class Ck whose i-th dimension equals xi, and Ni is the number of possible values of the i-th attribute.


### GaussianNB, MultinomialNB, BernoulliNB

GaussianNB is naive Bayes that assumes the features follow a Gaussian distribution, MultinomialNB assumes a multinomial distribution, and BernoulliNB assumes a Bernoulli distribution.

These three variants (for example, the classes of the same names in scikit-learn's `sklearn.naive_bayes` module) suit different classification scenarios. In general, if most sample features are continuous values, GaussianNB is the better choice. MultinomialNB is appropriate if most sample features are multi-valued discrete values. If the sample features are binary, or are very sparse multi-valued discrete values, BernoulliNB should be used.
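
The counting rules above, including Laplacian correction, can be sketched in plain Python for discrete features. This is a minimal illustration, not scikit-learn's implementation: the function names and the toy data are hypothetical, and Ni is assumed to be the number of distinct values of the i-th feature seen in training. The sketch sums log-probabilities instead of multiplying probabilities (a common choice not described above) to avoid numerical underflow when there are many dimensions; since log is monotonic, the predicted class is the same.

```python
import math
from collections import Counter


def fit_categorical_nb(X, y):
    """Collect the counts needed for naive Bayes on discrete features.

    X: list of samples, each a tuple of discrete feature values
    y: list of class labels
    """
    classes = sorted(set(y))
    n_features = len(X[0])
    class_counts = Counter(y)  # |S_Ck| for every class
    # value_counts[c][i][v] = number of class-c samples whose i-th feature equals v
    value_counts = {c: [Counter() for _ in range(n_features)] for c in classes}
    for xs, c in zip(X, y):
        for i, v in enumerate(xs):
            value_counts[c][i][v] += 1
    # N_i: number of distinct values of the i-th feature seen in training (assumption)
    n_values = [len({xs[i] for xs in X}) for i in range(n_features)]
    return classes, class_counts, value_counts, n_values, len(y)


def predict_categorical_nb(model, xs):
    """Return the class with the highest P(X, Y=C_k), using Laplacian correction."""
    classes, class_counts, value_counts, n_values, n = model
    N = len(classes)
    best, best_score = None, -math.inf
    for c in classes:
        # log P(Y=C_k) = log((|S_Ck| + 1) / (|S| + N))
        score = math.log((class_counts[c] + 1) / (n + N))
        for i, v in enumerate(xs):
            # log P(x_i | Y=C_k) = log((|S_k,xi| + 1) / (|S_k| + N_i)); unseen values count as 0
            score += math.log((value_counts[c][i][v] + 1) / (class_counts[c] + n_values[i]))
        if score > best_score:
            best, best_score = c, score
    return best


def gaussian_likelihood(x, mu, sigma):
    """P(x_i | Y=C_k) for a continuous feature under the Gaussian assumption."""
    return math.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / (math.sqrt(2 * math.pi) * sigma)


# hypothetical toy data: (cap color, cap surface) -> edible (e) / poisonous (p)
X = [("red", "smooth"), ("red", "scaly"), ("brown", "smooth"), ("brown", "scaly"), ("brown", "smooth")]
y = ["p", "p", "e", "e", "e"]
model = fit_categorical_nb(X, y)
print(predict_categorical_nb(model, ("red", "smooth")))    # p
print(predict_categorical_nb(model, ("green", "smooth")))  # e
```

The second prediction uses a color ("green") that never appears in training. Without the +1 terms its likelihood would be 0 for both classes; with Laplacian correction the other dimension still decides the result.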

## Optimizer


Variants of gradient descent differ in how much data is used in each iteration:

• Full Batch Gradient Descent (BGD): uses the whole dataset (size n)

• Stochastic Gradient Descent (SGD): uses one sample per iteration (size 1)

• Mini Batch Gradient Descent (MSGD): uses a mini batch of data (size m < n)

**BGD** has two disadvantages. First, every iteration computes the gradient over the whole dataset, which takes a lot of time on large data. Second, every iteration sees exactly the same data, so the gradient carries no noise: once the parameters reach a local minimum, the full-batch gradient is 0 and the optimization stops there.

**SGD** is often used in online settings because each update is cheap. But since SGD uses only one sample per iteration, the gradient is easily affected by noisy points, so the parameters keep fluctuating and may not settle at the optimal solution.

**MSGD** is a trade-off between BGD and SGD. It uses a small batch of data in each iteration, which averages out much of the noise of single-sample updates while keeping each iteration far cheaper than a full pass over the data.

### Learning rate

The learning rate is a vital parameter in gradient descent because it controls convergence. If the learning rate is small, convergence is slow. If it is large, each step is bigger and the loss can drop quickly at first, but a learning rate that is too large overshoots the minimum and may oscillate or even diverge.

Compare different learning rates:

![image](https://gimg2.baidu.com/image_search/src=http%3A%2F%2Fimg2018.cnblogs.com%2Fblog%2F1217276%2F201810%2F1217276-20181007182807634-196732269.png&refer=http%3A%2F%2Fimg2018.cnblogs.com&app=2002&size=f9999,10000&q=a80&n=0&g=0n&fmt=jpeg?sec=1617069776&t=b25621a89b513f8b765ac8f116bee051)


So how do we find a proper learning rate?
If the learning rate is large, the loss decreases quickly at the beginning but may overshoot the optimal solution; if it is small, convergence takes too long. Ideally, the learning rate shrinks as the iterations go on: when the parameters are close to the optimal solution, the steps should be small so that they can settle on it. So we need to change the learning rate gradually.


>A very simple method for changing the learning rate dynamically is ***Bold Driver***:
>
>At each iteration, compute the cost l(θ0,θ1...)
>
>Better than last time? If the cost decreases, increase the learning rate:
>
>`lr = 1.05 * lr`
>
>Worse than last time? If the cost increases, decrease the learning rate:
>
>`lr = 0.5 * lr`

Another method is ***Time-Based Decay***, which has the form lr = lr0/(1+kt),

where lr0 (the initial learning rate) and k (the decay rate) are hyperparameters and t is the iteration number.

These graphs illustrate the advantage of a time-based decay learning rate over a constant one:

***constant lr***

![image](https://miro.medium.com/max/864/1*Lv7-jMtHOoucryv9mUtFGg.jpeg)


***Time-Based Decay lr***

![image](https://miro.medium.com/max/864/1*YpzU0MkpNaZ8f6cGvqex7g.jpeg)

Gradient magnitudes also differ between coefficients: some coefficients have large gradients while others have small ones. With a single global learning rate, as in plain SGD, the coefficients are therefore updated at very different speeds, so we would like to balance the step size of each coefficient during gradient descent.
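
The batch-size variants and the two learning-rate schedules described above can be sketched together in NumPy. This is a minimal illustration under stated assumptions, not the author's implementation: `minibatch_gd`, `time_based_decay`, `bold_driver`, and `mse` are hypothetical helper names, the model is the one-feature line Y = θ1X + θ0 from the linear regression section, and the hyperparameters (lr0=0.01, k=0.01, batch size 2) are arbitrary choices.

```python
import numpy as np


def time_based_decay(lr0, k, t):
    """Time-based decay: lr = lr0 / (1 + k * t)."""
    return lr0 / (1 + k * t)


def bold_driver(lr, prev_cost, cost, up=1.05, down=0.5):
    """Bold Driver: raise lr by 5% if the cost went down, halve it if the cost went up."""
    if cost < prev_cost:
        return lr * up
    if cost > prev_cost:
        return lr * down
    return lr


def minibatch_gd(x, y, lr0=0.01, k=0.01, batch_size=2, epochs=200, seed=0):
    """Fit y ~ theta1 * x + theta0 with mini-batch gradient descent and time-based decay.

    batch_size=len(x) gives full-batch GD; batch_size=1 gives SGD.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    theta0, theta1, t = 0.0, 0.0, 0
    for _ in range(epochs):
        order = rng.permutation(len(x))  # reshuffle the samples every epoch
        for s in range(0, len(x), batch_size):
            idx = order[s:s + batch_size]
            err = theta1 * x[idx] + theta0 - y[idx]
            # gradient of (1 / 2m) * sum(err ** 2) over this mini-batch
            g1, g0 = np.mean(err * x[idx]), np.mean(err)
            lr = time_based_decay(lr0, k, t)
            theta1 -= lr * g1
            theta0 -= lr * g0
            t += 1
    return theta0, theta1


def mse(theta0, theta1, x, y):
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    return np.mean((theta1 * x + theta0 - y) ** 2)


# the five (x, y) points from the simple linear regression example
x = [1, 2, 3, 4, 5]
y = [18, 22, 45, 49, 86]
theta0, theta1 = minibatch_gd(x, y)
print(theta0, theta1, mse(theta0, theta1, x, y))
```

Bold Driver is shown only as an update rule; in practice it would be called once per epoch with the previous and current full-data cost, in place of the time-based schedule.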

To deal with oscillating updates and with unequal gradient scales across coefficients, we can use **SGDM**, **Adagrad (adaptive gradient algorithm)**, **RMSProp**, and **Adam**.

### SGDM

SGDM is SGD with momentum. It keeps an exponentially decaying sum of past gradients (the momentum m) and moves the coefficients along it:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20m_j%20%5Cleftarrow%20%5Clambda%20m_%7Bj%7D%20+%20%5Ceta%20%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%5Cnonumber%20%5C%5C%20%5Ctheta%20_j%20%5Cleftarrow%20%5Ctheta%20_j%20-%20m_j%5Cnonumber%20%5Cend%7Balign%7D)

where m0 = 0, λ is the momentum coefficient, and η is the learning rate.



Visualization:

![image](https://imgconvert.csdnimg.cn/aHR0cHM6Ly9pbnRyYW5ldHByb3h5LmFsaXBheS5jb20vc2t5bGFyay9sYXJrLzAvMjAyMC9wbmcvOTMwNC8xNTk4NTIxNDQ4NTQwLWViMjEwNTQ5LWNiOTMtNDIxMC05NDJmLTg2Mzk0Y2Y4Njk5ZC5wbmc?x-oss-process=image/format,png#align=left&display=inline&height=306&margin=%5Bobject%20Object%5D&name=image.png&originHeight=682&originWidth=1080&size=488854&status=done&style=none&width=484wZw#pic_center)


### Adagrad

Adagrad gives each coefficient its own learning rate:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20h_j%20%5Cleftarrow%20h_j%20+%28%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5E%7B2%7D%20%5C%5C%20%5Ctheta%20_j%20%5Cleftarrow%20%5Ctheta%20_j%20-%20%5Ceta%20%5Cfrac%7B1%7D%7B%5Csqrt%7Bh_j%7D%7D%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%20%5Cend%7Balign%7D)

θ : coefficients

∂l/∂θ : gradient

η : learning rate

hj : sum of the squares of all previous gradients of θj

When updating a coefficient, the step is scaled by 1/√hj: coefficients that have had large gradients take smaller steps, and coefficients that have had small gradients take larger ones.

However, hj only ever grows, so as the iterations go on the steps become very small and learning can stall.
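
A minimal NumPy sketch of one SGDM step and one Adagrad step, following the update rules above. The function names, the toy loss (a quadratic whose two gradients differ by a factor of 100), the learning rates, and the small `eps` added to avoid division by zero in Adagrad are all illustrative assumptions.

```python
import numpy as np


def sgdm_step(theta, m, grad, lr=0.01, momentum=0.9):
    """SGD with momentum: m <- lambda * m + eta * grad, then theta <- theta - m."""
    m = momentum * m + lr * grad
    return theta - m, m


def adagrad_step(theta, h, grad, lr=0.1, eps=1e-8):
    """Adagrad: h accumulates squared gradients, so each coefficient gets its own step size."""
    h = h + grad ** 2
    return theta - lr * grad / (np.sqrt(h) + eps), h


# hypothetical loss l(theta) = 0.5 * (100 * theta_0^2 + theta_1^2), minimum at (0, 0);
# the gradient of theta_0 is 100 times larger than that of theta_1
scales = np.array([100.0, 1.0])


def grad_fn(theta):
    return scales * theta


theta_m, m = np.array([1.0, 1.0]), np.zeros(2)
theta_a, h = np.array([1.0, 1.0]), np.zeros(2)
for _ in range(200):
    theta_m, m = sgdm_step(theta_m, m, grad_fn(theta_m), lr=0.005, momentum=0.9)
    theta_a, h = adagrad_step(theta_a, h, grad_fn(theta_a), lr=0.5)
print(theta_m, theta_a)  # both end up close to [0, 0]
```

Adagrad's first step moves both coefficients by essentially the same amount (lr), even though their gradients differ by a factor of 100; this is the per-coefficient balancing described above.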

### RMSProp

**RMSProp** addresses this problem. Instead of a running sum, it keeps an exponentially weighted average of squared gradients, so old gradients are gradually forgotten and hj no longer grows without bound. A dimension with large recent derivatives gets a large average (and therefore a smaller step), while a dimension with small derivatives gets a small average (and a larger step). This keeps the updates of all dimensions on a similar scale and reduces oscillation:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20h_j%20%5Cleftarrow%20%5Cbeta%20h_j%20+%281-%5Cbeta%29%28%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5E%7B2%7D%20%5C%5C%20%5Ctheta%20_j%20%5Cleftarrow%20%5Ctheta%20_j%20-%20%5Ceta%20%5Cfrac%7B1%7D%7B%5Csqrt%7Bh_j+c%7D%7D%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%20%5Cend%7Balign%7D)

hj can be 0 (for example, when a coefficient's gradients have all been 0), so a small constant c is added inside the square root to avoid dividing by zero.

***RMSProp code here*** (linear regression with a squared-error loss)

```
import numpy as np


def RMSprop(x, y, lr=0.01, iter_count=500, batch_size=4, beta=0.9, eps=1e-6):
    length, features = x.shape
    # append a column of ones so the last weight acts as the intercept
    data = np.column_stack((x, np.ones((length, 1))))
    y = np.asarray(y, dtype=float).reshape(-1, 1)
    w = np.zeros((features + 1, 1))
    h = np.zeros((features + 1, 1))  # per-coefficient average of squared gradients
    start = 0
    for i in range(iter_count):
        # take the next mini-batch, wrapping around the end of the data
        idx = np.arange(start, start + batch_size) % length
        xb, yb = data[idx], y[idx]
        # gradient of the mean squared error on this batch
        dw = xb.T @ (xb @ w - yb) / batch_size
        # exponentially weighted average of squared gradients (element-wise)
        h = beta * h + (1 - beta) * dw ** 2
        # update w: each coefficient gets its own effective step size
        w = w - lr / np.sqrt(h + eps) * dw
        start = (start + batch_size) % length
    return w
```
### Adam

***Adam*** is another widely used optimizer. Besides an exponentially weighted average of past squared gradients (h, as in RMSProp), it also keeps an exponentially weighted average of past gradients (m, the momentum):

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20m_j%20%5Cleftarrow%20%5Cbeta_1%20m_j%20+%281-%5Cbeta%20_1%29%28%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5Cnonumber%5C%5C%20h_j%20%5Cleftarrow%20%5Cbeta_2%20h_j%20+%281-%5Cbeta%20_2%29%28%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5E%7B2%7D%5Cnonumber%20%5Cend%7Balign%7D)

Because m and h are initialized to 0, they are biased toward 0 during the first iterations. Bias correction offsets this by rescaling them:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20%5Chat%7Bm%7D_j%20%3D%20%5Cfrac%7Bm_j%7D%7B1-%5Cbeta%20_1%5E%7Bt%7D%7D%5Cnonumber%5C%5C%20%5Chat%7Bh%7D_j%20%3D%20%5Cfrac%7Bh_j%7D%7B1-%5Cbeta%20_2%5E%7Bt%7D%7D%5Cnonumber%20%5Cend%7Balign%7D)

where t is the iteration number (starting from 1). The coefficients are then updated with the bias-corrected values:

![equation](https://latex.codecogs.com/gif.latex?%5Ctheta%20_j%20%5Cleftarrow%20%5Ctheta%20_j%20-%20%5Ceta%20%5Cfrac%7B1%7D%7B%5Csqrt%7B%5Chat%7Bh%7D_j+c%7D%7D%5Chat%7Bm%7D_j)


***Adam code here***

```
import numpy as np


def Adam(x, y, lr=0.01, iter_count=500, batch_size=4, beta1=0.9, beta2=0.999, eps=1e-6):
    length, features = x.shape
    # append a column of ones so the last weight acts as the intercept
    data = np.column_stack((x, np.ones((length, 1))))
    y = np.asarray(y, dtype=float).reshape(-1, 1)
    w = np.zeros((features + 1, 1))
    m = np.zeros((features + 1, 1))  # average of past gradients (momentum)
    h = np.zeros((features + 1, 1))  # average of past squared gradients
    start = 0
    for t in range(1, iter_count + 1):
        # take the next mini-batch, wrapping around the end of the data
        idx = np.arange(start, start + batch_size) % length
        xb, yb = data[idx], y[idx]
        # gradient of the mean squared error on this batch
        dw = xb.T @ (xb @ w - yb) / batch_size
        # update the running averages
        m = beta1 * m + (1 - beta1) * dw
        h = beta2 * h + (1 - beta2) * dw ** 2
        # bias correction, kept separate so m and h are not rescaled repeatedly
        m_hat = m / (1 - beta1 ** t)
        h_hat = h / (1 - beta2 ** t)
        # update w
        w = w - lr / np.sqrt(h_hat + eps) * m_hat
        start = (start + batch_size) % length
    return w
```
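
To see what the bias correction does, here is a small worked example: with a hypothetical gradient g held constant, the raw average m after t steps equals (1 - β1^t)·g, so it starts at only a tenth of g when β1 = 0.9, while the corrected m̂ equals g from the very first iteration. The constant gradient is an assumption made only to make the effect easy to see.

```python
import numpy as np

beta1, beta2 = 0.9, 0.999
g = np.array([0.5, -2.0])  # hypothetical gradient, held constant for this demo
m, h = np.zeros(2), np.zeros(2)
raw_m, corrected_m = [], []
for t in range(1, 6):
    m = beta1 * m + (1 - beta1) * g
    h = beta2 * h + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)  # bias-corrected momentum
    h_hat = h / (1 - beta2 ** t)  # bias-corrected average of squared gradients
    raw_m.append(m.copy())
    corrected_m.append(m_hat.copy())

print(raw_m[0])        # about 0.1 * g: biased toward the zero initialization
print(corrected_m[0])  # equal to g (up to rounding)
```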

### How to choose an optimizer

By far the most popular optimizers are SGDM and Adam.

![image](https://imgconvert.csdnimg.cn/aHR0cHM6Ly9pbnRyYW5ldHByb3h5LmFsaXBheS5jb20vc2t5bGFyay9sYXJrLzAvMjAyMC9wbmcvOTMwNC8xNTk4NTIzMDIxNTA5LTMyNTI1OGIwLTI5NzItNGNiNy04MDhkLTg4OTQ0Mzk0MWE3ZC5wbmc?x-oss-process=image/format,png#align=left&display=inline&height=302&margin=%5Bobject%20Object%5D&name=image.png&originHeight=604&originWidth=1074&size=76636&status=done&style=none&width=537wZw#pic_center)

This figure suggests that SGDM is mostly used in computer vision, whereas Adam is more popular in NLP.


### Improving Adam and SGDM

For Adam, there are ***SWATS***, ***AMSGrad***, ***AdaBound***, ***AdamW***, and ***NAdam***.

For SGDM, there are ***SGDMW*** and ***SGDMN***.

#### SWATS

SWATS combines Adam and SGDM: it starts training with Adam and switches to SGDM once a switching condition is met:

![image](https://imgconvert.csdnimg.cn/aHR0cHM6Ly9pbnRyYW5ldHByb3h5LmFsaXBheS5jb20vc2t5bGFyay9sYXJrLzAvMjAyMC9wbmcvOTMwNC8xNTk4NTI0MjE0NzU0LTkwY2VlMmE0LTFiZWYtNGRhNC1hY2M5LTljYjVhMjE2ZTBmMS5wbmc?x-oss-process=image/format,png#align=left&display=inline&height=146&margin=%5Bobject%20Object%5D&name=image.png&originHeight=292&originWidth=1066&size=43146&status=done&style=none&width=533wZw#pic_center)

#### AMSGrad

AMSGrad changes the way Adam updates ***hj*** by keeping the maximum value seen so far:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20m_j%20%5E%7Bi%7D%5Cleftarrow%20%5Cbeta_1%20m_j%20%5E%7Bi-1%7D+%281-%5Cbeta%20_1%29%28%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5Cnonumber%5C%5C%20h_j%5E%7Bi%7D%20%5Cleftarrow%20%5Cbeta_2%20h_%7Bj%7D%5E%7Bi-1%7D+%281-%5Cbeta%20_2%29%28%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5E%7B2%7D%5Cnonumber%5C%5C%20AMSGrad%20%3A%20%5Chat%7Bh%7D_j%5E%7Bi%7D%3D%20max%28%5Chat%7Bh%7D_j%5E%7Bi-1%7D%2Ch_%7Bj%7D%5E%7Bi%7D%29%5Cnonumber%20%5Cend%7Balign%7D)

and uses this maximum in place of hj in the Adam update. Because the maximum never decreases, the effective learning rate can only decrease, and a run of small gradients can no longer shrink hj and trigger an unexpectedly large step.
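
A minimal NumPy sketch of one AMSGrad step, showing the running maximum of h. For brevity it omits bias correction (an assumption of this sketch); the function name and the gradient sequence (one large gradient followed by small ones) are hypothetical. With plain Adam, h would shrink after the large gradient and the step size would grow again; with AMSGrad the effective learning rate stays non-increasing.

```python
import numpy as np


def amsgrad_step(theta, m, h, h_max, grad, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """One AMSGrad step (without bias correction)."""
    m = beta1 * m + (1 - beta1) * grad
    h = beta2 * h + (1 - beta2) * grad ** 2
    # keep the largest h seen so far; the effective step lr / sqrt(h_max) can only shrink
    h_max = np.maximum(h_max, h)
    theta = theta - lr * m / np.sqrt(h_max + eps)
    return theta, m, h, h_max


# hypothetical gradient sequence: one large gradient followed by small ones
theta, m, h, h_max = np.zeros(1), np.zeros(1), np.zeros(1), np.zeros(1)
effective_lr = []
for g in [10.0, 0.1, 0.1, 0.1]:
    theta, m, h, h_max = amsgrad_step(theta, m, h, h_max, np.array([g]))
    # same lr and eps as the defaults of amsgrad_step
    effective_lr.append(0.01 / np.sqrt(h_max[0] + 1e-8))
print(h, h_max, effective_lr)
```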

#### AdaBound

AdaBound clips each coefficient's learning rate to a range that narrows as training goes on, so the optimizer behaves like Adam early on and like SGD with a fixed learning rate later:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20m_j%20%5E%7Bi%7D%5Cleftarrow%20%5Cbeta_1%20m_j%20%5E%7Bi-1%7D+%281-%5Cbeta%20_1%29%28%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5Cnonumber%5C%5C%20h_j%5E%7Bi%7D%20%5Cleftarrow%20%5Cbeta_2%20h_%7Bj%7D%5E%7Bi-1%7D+%281-%5Cbeta%20_2%29%28%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5E%7B2%7D%5Cnonumber%5C%5C%20AdaBound%20%3A%20%5Ctheta%20_j%20%5Cleftarrow%20%5Ctheta%20_j%20-%20clip%28%5Ceta%20%5Cfrac%7B1%7D%7B%5Csqrt%7Bh_j+c%7D%7D%29m_j%20%5Cnonumber%5C%5C%20where%20%3A%20clip%28x%29%3D%20np.clip%28x%2C0.1-%5Cfrac%7B0.1%7D%7B%281-%5Cbeta%20_2%29t+1%7D%2C0.1+%5Cfrac%7B0.1%7D%7B%281-%5Cbeta%20_2%29t%7D%29%5Cnonumber%20%5Cend%7Balign%7D)



#### AdamW

AdamW adds weight decay to Adam. The decay term γθj is applied directly in the update (decoupled from the gradient) rather than being added to the loss:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20m_j%20%5E%7Bi%7D%5Cleftarrow%20%5Cbeta_1%20m_j%20%5E%7Bi-1%7D+%281-%5Cbeta%20_1%29%28%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5Cnonumber%5C%5C%20h_j%5E%7Bi%7D%20%5Cleftarrow%20%5Cbeta_2%20h_%7Bj%7D%5E%7Bi-1%7D+%281-%5Cbeta%20_2%29%28%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5E%7B2%7D%5Cnonumber%5C%5C%20AdamW%20%3A%20%5Ctheta%20_j%20%5Cleftarrow%20%5Ctheta%20_j%20-%20%5Ceta%20%28%5Cfrac%7B1%7D%7B%5Csqrt%7Bh_j+c%7D%7Dm_j+%5Cgamma%20%5Ctheta%20_j%20%29%5Cnonumber%20%5Cend%7Balign%7D)



#### SGDMW

SGDMW adds weight decay to SGDM:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20m_j%20%5Cleftarrow%20%5Clambda%20m_%7Bj%7D%20+%20%5Ceta%20%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%5Cnonumber%20%5C%5C%20%5Ctheta%20_j%20%5Cleftarrow%20%5Ctheta%20_j%20-%20m_j-%20%5Cgamma%20%5Ctheta%20_j%5Cnonumber%20%5Cend%7Balign%7D)

where m0 = 0, λ is the momentum coefficient, η is the learning rate, and γ is the weight decay coefficient.


#### SGDMN

SGDMN (SGDM with Nesterov momentum) evaluates the gradient at the look-ahead point θ - λm, that is, where the momentum alone would carry the coefficients, instead of at the current point. This lets it correct the step before it overshoots a minimum:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20m_j%20%5Cleftarrow%20%5Clambda%20m_%7Bj%7D%20+%20%5Ceta%20%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%28%5Ctheta%20-%20%5Clambda%20m%29%5Cnonumber%20%5C%5C%20%5Ctheta%20_j%20%5Cleftarrow%20%5Ctheta%20_j%20-%20m_j%5Cnonumber%20%5Cend%7Balign%7D)


#### NAdam

NAdam (Adam with Nesterov momentum) brings the same look-ahead idea to Adam: its update combines the bias-corrected momentum with the current gradient, which approximates applying the momentum step before computing the gradient:

![equation](https://latex.codecogs.com/gif.latex?%5Cbegin%7Balign%7D%20m_j%20%5Cleftarrow%20%5Cbeta_1%20m_j%20+%281-%5Cbeta%20_1%29%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%5Cnonumber%5C%5C%20h_j%20%5Cleftarrow%20%5Cbeta_2%20h_j%20+%281-%5Cbeta%20_2%29%28%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5E%7B2%7D%5Cnonumber%5C%5C%20%5Chat%7Bm%7D_j%20%3D%20%5Cfrac%7Bm_j%7D%7B1-%5Cbeta%20_1%5E%7Bt%7D%7D%2C%20%5Chat%7Bh%7D_j%20%3D%20%5Cfrac%7Bh_j%7D%7B1-%5Cbeta%20_2%5E%7Bt%7D%7D%5Cnonumber%5C%5C%20%5Ctheta%20_j%20%5Cleftarrow%20%5Ctheta%20_j%20-%20%5Cfrac%7B%5Ceta%7D%7B%5Csqrt%7B%5Chat%7Bh%7D_j+c%7D%7D%28%5Cbeta%20_1%20%5Chat%7Bm%7D_j%20+%20%5Cfrac%7B1-%5Cbeta%20_1%7D%7B1-%5Cbeta%20_1%5E%7Bt%7D%7D%5Cfrac%7B%5Cpartial%20l%7D%7B%5Cpartial%20%5Ctheta%20_j%7D%29%5Cnonumber%20%5Cend%7Balign%7D)


The figure below compares these optimizers:

![image](https://miro.medium.com/max/892/1*63HMdMyw_XDcNkRCQ1nrpw.png)



### Now I will share a Kaggle project based on ML:

***Kaggle link:***

https://www.kaggle.com/uciml/mushroom-classification


This project is mushroom classification. There are 2 classes (edible or poisonous) and 8124 records (52% edible and 48% poisonous).

For this project, I tried three machine learning models: SVM, random forest, and logistic regression. All three models perform well. Comparing recall, precision, and F1-score for both classes:

• F1-score > 95%

• Accuracy > 95%

This dataset includes descriptions of hypothetical samples corresponding to 23 species of gilled mushrooms in the Agaricus and Lepiota family, drawn from The Audubon Society Field Guide to North American Mushrooms (1981). Each species is identified as definitely edible, definitely poisonous, or of unknown edibility and not recommended. This latter class was combined with the poisonous one. The Guide clearly states that there is no simple rule for determining the edibility of a mushroom; no rule like "leaflets three, let it be" for Poisonous Oak and Ivy.
754 | 755 | ***Code*** 756 | 757 | #### Check [here](https://github.com/gnayoaixgnaw/machine_learning_project/blob/main/mashroom/cs677project.ipynb) 758 | 759 | 760 | 761 | 762 | 763 | -------------------------------------------------------------------------------- /image/1: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /image/CodeCogsEqn (1).gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/CodeCogsEqn (1).gif -------------------------------------------------------------------------------- /image/adagrad1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/adagrad1.png -------------------------------------------------------------------------------- /image/adaptivelr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/adaptivelr.png -------------------------------------------------------------------------------- /image/derivative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/derivative.png -------------------------------------------------------------------------------- /image/derivative1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/derivative1.png 
-------------------------------------------------------------------------------- /image/derivative2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/derivative2.png -------------------------------------------------------------------------------- /image/imbd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/imbd.png -------------------------------------------------------------------------------- /image/lr_large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/lr_large.png -------------------------------------------------------------------------------- /image/lr_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/lr_small.png -------------------------------------------------------------------------------- /image/mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/mnist.png -------------------------------------------------------------------------------- /image/pseudocode1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/pseudocode1.png -------------------------------------------------------------------------------- /image/pseudocode2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/image/pseudocode2.png -------------------------------------------------------------------------------- /mashroom/cs677-project.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnayoaixgnaw/machine_learning_project/6b39513bea4b16288fb5e5ad06b4d7c81c1ecd72/mashroom/cs677-project.pptx -------------------------------------------------------------------------------- /mashroom/cs677project.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import numpy as np\n", 11 | "\n", 12 | "from sklearn.decomposition import PCA\n", 13 | "from sklearn.pipeline import make_pipeline\n", 14 | "from sklearn.model_selection import train_test_split\n", 15 | "from sklearn.model_selection import GridSearchCV\n", 16 | "from sklearn.svm import SVC\n", 17 | "from sklearn.metrics import classification_report\n", 18 | "from sklearn.metrics import f1_score\n", 19 | "from sklearn.metrics import confusion_matrix\n", 20 | "from sklearn.metrics import classification_report\n", 21 | "from sklearn.neighbors import KNeighborsClassifier\n", 22 | "from sklearn.linear_model import LogisticRegression\n", 23 | "from sklearn.linear_model import SGDClassifier\n", 24 | "from sklearn.ensemble import RandomForestClassifier\n", 25 | "from sklearn.model_selection import cross_val_score\n", 26 | "\n", 27 | "\n", 28 | "\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "\n" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "raw_data = pd.read_csv('mushrooms.csv')\n", 
40 | "dropna_data = raw_data.dropna(axis=0, how='any')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 4, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "data": { 50 | "text/plain": [ 51 | "array(['class', 'cap-shape', 'cap-surface', 'cap-color', 'bruises',\n", 52 | " 'odor', 'gill-attachment', 'gill-spacing', 'gill-size',\n", 53 | " 'gill-color', 'stalk-shape', 'stalk-root',\n", 54 | " 'stalk-surface-above-ring', 'stalk-surface-below-ring',\n", 55 | " 'stalk-color-above-ring', 'stalk-color-below-ring', 'veil-type',\n", 56 | " 'veil-color', 'ring-number', 'ring-type', 'spore-print-color',\n", 57 | " 'population', 'habitat'], dtype=object)" 58 | ] 59 | }, 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "output_type": "execute_result" 63 | } 64 | ], 65 | "source": [ 66 | "columns = dropna_data.columns.values\n", 67 | "columns\n", 68 | "# Attribute Information: (classes: edible=e, poisonous=p)\n", 69 | "\n", 70 | "# cap-shape: bell=b,conical=c,convex=x,flat=f, knobbed=k,sunken=s\n", 71 | "\n", 72 | "# cap-surface: fibrous=f,grooves=g,scaly=y,smooth=s\n", 73 | "\n", 74 | "# cap-color: brown=n,buff=b,cinnamon=c,gray=g,green=r,pink=p,purple=u,red=e,white=w,yellow=y\n", 75 | "\n", 76 | "# bruises: bruises=t,no=f\n", 77 | "\n", 78 | "# odor: almond=a,anise=l,creosote=c,fishy=y,foul=f,musty=m,none=n,pungent=p,spicy=s\n", 79 | "\n", 80 | "# gill-attachment: attached=a,descending=d,free=f,notched=n\n", 81 | "\n", 82 | "# gill-spacing: close=c,crowded=w,distant=d\n", 83 | "\n", 84 | "# gill-size: broad=b,narrow=n\n", 85 | "\n", 86 | "# gill-color: black=k,brown=n,buff=b,chocolate=h,gray=g, green=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y\n", 87 | "\n", 88 | "# stalk-shape: enlarging=e,tapering=t\n", 89 | "\n", 90 | "# stalk-root: bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r,missing=?\n", 91 | "\n", 92 | "# stalk-surface-above-ring: fibrous=f,scaly=y,silky=k,smooth=s\n", 93 | "\n", 94 | "# stalk-surface-below-ring: 
fibrous=f,scaly=y,silky=k,smooth=s\n", 95 | "\n", 96 | "# stalk-color-above-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y\n", 97 | "\n", 98 | "# stalk-color-below-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y\n", 99 | "\n", 100 | "# veil-type: partial=p,universal=u\n", 101 | "\n", 102 | "# veil-color: brown=n,orange=o,white=w,yellow=y\n", 103 | "\n", 104 | "# ring-number: none=n,one=o,two=t\n", 105 | "\n", 106 | "# ring-type: cobwebby=c,evanescent=e,flaring=f,large=l,none=n,pendant=p,sheathing=s,zone=z\n", 107 | "\n", 108 | "# spore-print-color: black=k,brown=n,buff=b,chocolate=h,green=r,orange=o,purple=u,white=w,yellow=y\n", 109 | "\n", 110 | "# population: abundant=a,clustered=c,numerous=n,scattered=s,several=v,solitary=y\n", 111 | "\n", 112 | "# habitat: grasses=g,leaves=l,meadows=m,paths=p,urban=u,waste=w,woods=d" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 5, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stderr", 122 | "output_type": "stream", 123 | "text": [ 124 | ":9: MatplotlibDeprecationWarning: Using a string of single character colors as a color sequence is deprecated. Use an explicit list instead.\n", 125 | " autolabel(plt.bar(range(len(num_list)), num_list, color='rgb', tick_label=name_list))\n" 126 | ] 127 | }, 128 | { 129 | "data": { 130 | "image/png": "\n", 131 | "text/plain": [ 132 | "
" 133 | ] 134 | }, 135 | "metadata": { 136 | "needs_background": "light" 137 | }, 138 | "output_type": "display_data" 139 | } 140 | ], 141 | "source": [ 142 | "def autolabel(rects):\n", 143 | " for rect in rects:\n", 144 | " height = rect.get_height()\n", 145 | " plt.text(rect.get_x()+rect.get_width()/2.- 0.2, 1.03*height, '%s' % int(height))\n", 146 | " \n", 147 | " \n", 148 | "name_list = dropna_data['cap-color'].value_counts().index\n", 149 | "num_list = dropna_data['cap-color'].value_counts()\n", 150 | "autolabel(plt.bar(range(len(num_list)), num_list, color='rgb', tick_label=name_list))\n", 151 | "plt.show()\n" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 6, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "name": "stderr", 161 | "output_type": "stream", 162 | "text": [ 163 | ":3: MatplotlibDeprecationWarning: Using a string of single character colors as a color sequence is deprecated. Use an explicit list instead.\n", 164 | " autolabel(plt.bar(range(len(num_list)), num_list, color='rgb', tick_label=name_list))\n" 165 | ] 166 | }, 167 | { 168 | "data": { 169 | "image/png": "\n", 170 | "text/plain": [ 171 | "
" 172 | ] 173 | }, 174 | "metadata": { 175 | "needs_background": "light" 176 | }, 177 | "output_type": "display_data" 178 | } 179 | ], 180 | "source": [ 181 | "name_list = dropna_data['odor'].value_counts().index\n", 182 | "num_list = dropna_data['odor'].value_counts()\n", 183 | "autolabel(plt.bar(range(len(num_list)), num_list, color='rgb', tick_label=name_list))\n", 184 | "plt.show()\n" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 7, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "def encoding(x):\n", 194 | " temp = list(dropna_data[x].value_counts().index)\n", 195 | " dic = {}\n", 196 | " index = 0\n", 197 | " result = []\n", 198 | " for i in temp:\n", 199 | " dic[i] = index\n", 200 | " index +=1\n", 201 | " for j in dropna_data[x]:\n", 202 | " result.append(dic[j])\n", 203 | " return result\n", 204 | " " 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 8, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "data = pd.DataFrame()\n", 214 | "for i in columns:\n", 215 | " number_list = encoding(i)\n", 216 | " data[i] = number_list" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 25, 222 | "metadata": {}, 223 | "outputs": [ 224 | { 225 | "data": { 226 | "text/html": [ 227 | "
\n", 823 | "

23 rows × 23 columns

\n", 824 | "
" 825 | ], 826 | "text/plain": [ 827 | " class cap-shape cap-surface cap-color \\\n", 828 | "class 1.000000 -0.034451 -0.159155 0.011735 \n", 829 | "cap-shape -0.034451 1.000000 -0.058042 0.029639 \n", 830 | "cap-surface -0.159155 -0.058042 1.000000 -0.009187 \n", 831 | "cap-color 0.011735 0.029639 -0.009187 1.000000 \n", 832 | "bruises -0.501530 -0.057287 -0.088630 0.108773 \n", 833 | "odor 0.323569 0.017134 -0.201957 0.220801 \n", 834 | "gill-attachment -0.129200 0.126426 0.005364 -0.150855 \n", 835 | "gill-spacing -0.348387 -0.005866 0.312855 0.115640 \n", 836 | "gill-size 0.540024 0.083962 -0.168927 -0.088396 \n", 837 | "gill-color -0.334999 -0.016882 0.136293 0.139617 \n", 838 | "stalk-shape 0.102019 0.116854 -0.026495 0.330198 \n", 839 | "stalk-root -0.248390 0.190482 -0.147657 0.092678 \n", 840 | "stalk-surface-above-ring 0.323350 -0.045912 0.080644 0.014094 \n", 841 | "stalk-surface-below-ring 0.182342 -0.059609 -0.060242 0.050341 \n", 842 | "stalk-color-above-ring 0.133658 -0.019245 -0.016559 -0.093657 \n", 843 | "stalk-color-below-ring 0.133722 -0.004047 -0.012579 -0.081589 \n", 844 | "veil-type NaN NaN NaN NaN \n", 845 | "veil-color -0.120766 0.139236 0.006500 -0.140329 \n", 846 | "ring-number -0.152261 0.195835 -0.004239 0.213419 \n", 847 | "ring-type 0.556515 -0.082190 -0.044634 0.038023 \n", 848 | "spore-print-color -0.025285 -0.119942 0.138218 0.124321 \n", 849 | "population -0.446307 0.094009 0.092017 0.195487 \n", 850 | "habitat 0.078160 0.216553 -0.135708 0.155114 \n", 851 | "\n", 852 | " bruises odor gill-attachment gill-spacing \\\n", 853 | "class -0.501530 0.323569 -0.129200 -0.348387 \n", 854 | "cap-shape -0.057287 0.017134 0.126426 -0.005866 \n", 855 | "cap-surface -0.088630 -0.201957 0.005364 0.312855 \n", 856 | "cap-color 0.108773 0.220801 -0.150855 0.115640 \n", 857 | "bruises 1.000000 0.067419 -0.137359 -0.299473 \n", 858 | "odor 0.067419 1.000000 -0.065285 -0.140206 \n", 859 | "gill-attachment -0.137359 -0.065285 1.000000 -0.071489 \n", 
860 | "gill-spacing -0.299473 -0.140206 -0.071489 1.000000 \n", 861 | "gill-size -0.369596 0.433342 -0.108984 -0.108333 \n", 862 | "gill-color 0.179959 -0.100251 0.303533 0.122796 \n", 863 | "stalk-shape -0.099364 0.327020 0.186485 -0.080895 \n", 864 | "stalk-root 0.008254 0.456076 0.045655 0.223367 \n", 865 | "stalk-surface-above-ring -0.468764 -0.158253 -0.090452 0.217483 \n", 866 | "stalk-surface-below-ring -0.303546 0.036040 -0.058385 0.120523 \n", 867 | "stalk-color-above-ring -0.168188 -0.188523 0.460462 -0.275012 \n", 868 | "stalk-color-below-ring -0.185653 -0.203276 0.444346 -0.247624 \n", 869 | "veil-type NaN NaN NaN NaN \n", 870 | "veil-color -0.125606 -0.113019 0.842230 -0.034235 \n", 871 | "ring-number 0.014367 -0.088455 0.049642 0.206233 \n", 872 | "ring-type -0.685119 -0.083359 -0.073149 -0.038015 \n", 873 | "spore-print-color 0.094309 -0.097657 0.362890 -0.068113 \n", 874 | "population 0.071736 -0.011620 0.193065 0.433383 \n", 875 | "habitat 0.010061 0.158634 0.146291 -0.115027 \n", 876 | "\n", 877 | " gill-size gill-color ... \\\n", 878 | "class 0.540024 -0.334999 ... \n", 879 | "cap-shape 0.083962 -0.016882 ... \n", 880 | "cap-surface -0.168927 0.136293 ... \n", 881 | "cap-color -0.088396 0.139617 ... \n", 882 | "bruises -0.369596 0.179959 ... \n", 883 | "odor 0.433342 -0.100251 ... \n", 884 | "gill-attachment -0.108984 0.303533 ... \n", 885 | "gill-spacing -0.108333 0.122796 ... \n", 886 | "gill-size 1.000000 -0.482295 ... \n", 887 | "gill-color -0.482295 1.000000 ... \n", 888 | "stalk-shape -0.214576 0.335438 ... \n", 889 | "stalk-root 0.083055 0.042577 ... \n", 890 | "stalk-surface-above-ring -0.054378 -0.041328 ... \n", 891 | "stalk-surface-below-ring -0.074886 -0.044905 ... \n", 892 | "stalk-color-above-ring -0.274632 0.241009 ... \n", 893 | "stalk-color-below-ring -0.228869 0.227199 ... \n", 894 | "veil-type NaN NaN ... \n", 895 | "veil-color -0.074866 0.278562 ... \n", 896 | "ring-number -0.190112 0.126014 ... 
\n", 897 | "ring-type 0.071459 -0.072866 ... \n", 898 | "spore-print-color -0.464968 0.416107 ... \n", 899 | "population -0.396480 0.375137 ... \n", 900 | "habitat 0.118962 0.075706 ... \n", 901 | "\n", 902 | " stalk-surface-below-ring stalk-color-above-ring \\\n", 903 | "class 0.182342 0.133658 \n", 904 | "cap-shape -0.059609 -0.019245 \n", 905 | "cap-surface -0.060242 -0.016559 \n", 906 | "cap-color 0.050341 -0.093657 \n", 907 | "bruises -0.303546 -0.168188 \n", 908 | "odor 0.036040 -0.188523 \n", 909 | "gill-attachment -0.058385 0.460462 \n", 910 | "gill-spacing 0.120523 -0.275012 \n", 911 | "gill-size -0.074886 -0.274632 \n", 912 | "gill-color -0.044905 0.241009 \n", 913 | "stalk-shape 0.175748 0.320923 \n", 914 | "stalk-root 0.243842 -0.292372 \n", 915 | "stalk-surface-above-ring 0.355547 0.131906 \n", 916 | "stalk-surface-below-ring 1.000000 0.076566 \n", 917 | "stalk-color-above-ring 0.076566 1.000000 \n", 918 | "stalk-color-below-ring 0.126544 0.666745 \n", 919 | "veil-type NaN NaN \n", 920 | "veil-color -0.058435 0.427022 \n", 921 | "ring-number 0.025897 0.134138 \n", 922 | "ring-type 0.417783 0.410277 \n", 923 | "spore-print-color 0.062337 0.353248 \n", 924 | "population 0.066446 0.062174 \n", 925 | "habitat -0.042043 0.037326 \n", 926 | "\n", 927 | " stalk-color-below-ring veil-type veil-color \\\n", 928 | "class 0.133722 NaN -0.120766 \n", 929 | "cap-shape -0.004047 NaN 0.139236 \n", 930 | "cap-surface -0.012579 NaN 0.006500 \n", 931 | "cap-color -0.081589 NaN -0.140329 \n", 932 | "bruises -0.185653 NaN -0.125606 \n", 933 | "odor -0.203276 NaN -0.113019 \n", 934 | "gill-attachment 0.444346 NaN 0.842230 \n", 935 | "gill-spacing -0.247624 NaN -0.034235 \n", 936 | "gill-size -0.228869 NaN -0.074866 \n", 937 | "gill-color 0.227199 NaN 0.278562 \n", 938 | "stalk-shape 0.342241 NaN 0.170529 \n", 939 | "stalk-root -0.293069 NaN 0.039223 \n", 940 | "stalk-surface-above-ring 0.151545 NaN -0.048532 \n", 941 | "stalk-surface-below-ring 0.126544 NaN -0.058435 \n", 
942 | "stalk-color-above-ring 0.666745 NaN 0.427022 \n", 943 | "stalk-color-below-ring 1.000000 NaN 0.412131 \n", 944 | "veil-type NaN NaN NaN \n", 945 | "veil-color 0.412131 NaN 1.000000 \n", 946 | "ring-number 0.123320 NaN -0.042328 \n", 947 | "ring-type 0.409658 NaN -0.117232 \n", 948 | "spore-print-color 0.316118 NaN 0.336581 \n", 949 | "population 0.040800 NaN 0.174195 \n", 950 | "habitat 0.036261 NaN 0.158578 \n", 951 | "\n", 952 | " ring-number ring-type spore-print-color \\\n", 953 | "class -0.152261 0.556515 -0.025285 \n", 954 | "cap-shape 0.195835 -0.082190 -0.119942 \n", 955 | "cap-surface -0.004239 -0.044634 0.138218 \n", 956 | "cap-color 0.213419 0.038023 0.124321 \n", 957 | "bruises 0.014367 -0.685119 0.094309 \n", 958 | "odor -0.088455 -0.083359 -0.097657 \n", 959 | "gill-attachment 0.049642 -0.073149 0.362890 \n", 960 | "gill-spacing 0.206233 -0.038015 -0.068113 \n", 961 | "gill-size -0.190112 0.071459 -0.464968 \n", 962 | "gill-color 0.126014 -0.072866 0.416107 \n", 963 | "stalk-shape 0.325305 0.224838 0.341107 \n", 964 | "stalk-root 0.042445 -0.134897 -0.238732 \n", 965 | "stalk-surface-above-ring -0.019722 0.495371 0.084645 \n", 966 | "stalk-surface-below-ring 0.025897 0.417783 0.062337 \n", 967 | "stalk-color-above-ring 0.134138 0.410277 0.353248 \n", 968 | "stalk-color-below-ring 0.123320 0.409658 0.316118 \n", 969 | "veil-type NaN NaN NaN \n", 970 | "veil-color -0.042328 -0.117232 0.336581 \n", 971 | "ring-number 1.000000 0.006528 -0.219389 \n", 972 | "ring-type 0.006528 1.000000 0.102536 \n", 973 | "spore-print-color -0.219389 0.102536 1.000000 \n", 974 | "population 0.412098 -0.070331 0.013230 \n", 975 | "habitat 0.225839 -0.011085 -0.072205 \n", 976 | "\n", 977 | " population habitat \n", 978 | "class -0.446307 0.078160 \n", 979 | "cap-shape 0.094009 0.216553 \n", 980 | "cap-surface 0.092017 -0.135708 \n", 981 | "cap-color 0.195487 0.155114 \n", 982 | "bruises 0.071736 0.010061 \n", 983 | "odor -0.011620 0.158634 \n", 984 | "gill-attachment 
0.193065 0.146291 \n", 985 | "gill-spacing 0.433383 -0.115027 \n", 986 | "gill-size -0.396480 0.118962 \n", 987 | "gill-color 0.375137 0.075706 \n", 988 | "stalk-shape 0.258171 0.302504 \n", 989 | "stalk-root 0.463674 0.354307 \n", 990 | "stalk-surface-above-ring 0.024179 -0.022573 \n", 991 | "stalk-surface-below-ring 0.066446 -0.042043 \n", 992 | "stalk-color-above-ring 0.062174 0.037326 \n", 993 | "stalk-color-below-ring 0.040800 0.036261 \n", 994 | "veil-type NaN NaN \n", 995 | "veil-color 0.174195 0.158578 \n", 996 | "ring-number 0.412098 0.225839 \n", 997 | "ring-type -0.070331 -0.011085 \n", 998 | "spore-print-color 0.013230 -0.072205 \n", 999 | "population 1.000000 0.301272 \n", 1000 | "habitat 0.301272 1.000000 \n", 1001 | "\n", 1002 | "[23 rows x 23 columns]" 1003 | ] 1004 | }, 1005 | "execution_count": 25, 1006 | "metadata": {}, 1007 | "output_type": "execute_result" 1008 | } 1009 | ], 1010 | "source": [ 1011 | "data.corr()" 1012 | ] 1013 | }, 1014 | { 1015 | "cell_type": "code", 1016 | "execution_count": 32, 1017 | "metadata": {}, 1018 | "outputs": [], 1019 | "source": [ 1020 | "def visualize_data(x):\n", 1021 | " data = x.values\n", 1022 | " \n", 1023 | " fig = plt.figure(figsize=(20, 20))\n", 1024 | " ax = fig.add_subplot(111)\n", 1025 | " \n", 1026 | " heatmap = ax.pcolor(data,cmap = plt.cm.RdYlGn)\n", 1027 | " fig.colorbar(heatmap)\n", 1028 | " ax.set_xticks(np.arange(data.shape[1]) + 0.5, minor=False)  # one tick per column, centered on its cell\n", 1029 | " ax.set_yticks(np.arange(data.shape[0]) + 0.5, minor=False)  # one tick per row, centered on its cell\n", 1030 | " ax.invert_yaxis()\n", 1031 | " ax.xaxis.tick_top()\n", 1032 | " \n", 1033 | " column_labels = x.columns\n", 1034 | " row_labels = x.index\n", 1035 | " \n", 1036 | " ax.set_xticklabels(column_labels)\n", 1037 | " ax.set_yticklabels(row_labels)\n", 1038 | " \n", 1039 | " plt.xticks(rotation = 90)\n", 1040 | " heatmap.set_clim(-1,1)\n", 1041 | " plt.tight_layout()\n", 1042 | " plt.show()\n", 1043 | " \n" 1044 | ] 1045 | }, 1046 | { 1047 | "cell_type": "code", 1048 |
"execution_count": 33, 1049 | "metadata": {}, 1050 | "outputs": [ 1051 | { 1052 | "data": { 1053 | "image/png": "\n", 1054 | "text/plain": [ 1055 | "
" 1056 | ] 1057 | }, 1058 | "metadata": { 1059 | "needs_background": "light" 1060 | }, 1061 | "output_type": "display_data" 1062 | } 1063 | ], 1064 | "source": [ 1065 | "visualize_data(data.corr())" 1066 | ] 1067 | }, 1068 | { 1069 | "cell_type": "code", 1070 | "execution_count": 45, 1071 | "metadata": {}, 1072 | "outputs": [], 1073 | "source": [ 1074 | "data.drop(['bruises'], axis=1)  # no effect: drop() returns a new DataFrame that is never assigned, so 'bruises' stays among the 22 features\n", 1075 | "data_data = data[columns[1:]].values\n", 1076 | "data_label= data[columns[0]].values \n" 1077 | ] 1078 | }, 1079 | { 1080 | "cell_type": "code", 1081 | "execution_count": 46, 1082 | "metadata": {}, 1083 | "outputs": [ 1084 | { 1085 | "data": { 1086 | "text/plain": [ 1087 | "8124" 1088 | ] 1089 | }, 1090 | "execution_count": 46, 1091 | "metadata": {}, 1092 | "output_type": "execute_result" 1093 | } 1094 | ], 1095 | "source": [ 1096 | "len(data_data)" 1097 | ] 1098 | }, 1099 | { 1100 | "cell_type": "code", 1101 | "execution_count": null, 1102 | "metadata": {}, 1103 | "outputs": [], 1104 | "source": [] 1105 | }, 1106 | { 1107 | "cell_type": "code", 1108 | "execution_count": 47, 1109 | "metadata": {}, 1110 | "outputs": [], 1111 | "source": [ 1112 | "# class labels: edible (e) = 0, poisonous (p) = 1" 1113 | ] 1114 | }, 1115 | { 1116 | "cell_type": "code", 1117 | "execution_count": 49, 1118 | "metadata": {}, 1119 | "outputs": [], 1120 | "source": [ 1121 | "x_train, x_test,y_train,y_test = train_test_split(data_data, data_label, test_size=0.5,random_state=15)\n", 1122 | "\n", 1123 | "pca = PCA(n_components=20,whiten=True, random_state= 32)\n", 1124 | "\n" 1125 | ] 1126 | }, 1127 | { 1128 | "cell_type": "code", 1129 | "execution_count": 50, 1130 | "metadata": {}, 1131 | "outputs": [ 1132 | { 1133 | "name": "stdout", 1134 | "output_type": "stream", 1135 | "text": [ 1136 | "(22,)\n" 1137 | ] 1138 | } 1139 | ], 1140 | "source": [ 1141 | "print(x_train[1].shape)" 1142 | ] 1143 | }, 1144 | { 1145 | "cell_type": "code", 1146 | "execution_count": 51, 1147 | "metadata": {}, 1148 | "outputs": [ 1149 | { 1150 |
"data": { 1151 | "text/plain": [ 1152 | "(4062, 22)" 1153 | ] 1154 | }, 1155 | "execution_count": 51, 1156 | "metadata": {}, 1157 | "output_type": "execute_result" 1158 | } 1159 | ], 1160 | "source": [ 1161 | "x_train.shape" 1162 | ] 1163 | }, 1164 | { 1165 | "cell_type": "code", 1166 | "execution_count": 60, 1167 | "metadata": {}, 1168 | "outputs": [ 1169 | { 1170 | "name": "stdout", 1171 | "output_type": "stream", 1172 | "text": [ 1173 | "-----------------svm------------------------------\n", 1174 | "(4062, 22) (4062,)\n", 1175 | "{'svc__C': 7, 'svc__gamma': 0.1}\n", 1176 | " precision recall f1-score support\n", 1177 | "\n", 1178 | " edible 1.00 1.00 1.00 2133\n", 1179 | " poisonous 1.00 1.00 1.00 1929\n", 1180 | "\n", 1181 | " accuracy 1.00 4062\n", 1182 | " macro avg 1.00 1.00 1.00 4062\n", 1183 | "weighted avg 1.00 1.00 1.00 4062\n", 1184 | "\n", 1185 | "0.9997407311381903\n", 1186 | "[[2133 0]\n", 1187 | " [ 1 1928]]\n" 1188 | ] 1189 | } 1190 | ], 1191 | "source": [ 1192 | "print('-----------------svm------------------------------')\n", 1193 | "\n", 1194 | "svc = SVC(kernel='rbf',class_weight='balanced')\n", 1195 | "\n", 1196 | "\n", 1197 | "\n", 1198 | "model = make_pipeline(pca, svc)\n", 1199 | "\n", 1200 | "\n", 1201 | "param_grid = {'svc__C': [7,9,11],\n", 1202 | " 'svc__gamma': [0.1,0.5,1]}\n", 1203 | "grid = GridSearchCV(model, param_grid)\n", 1204 | "print(x_train.shape, y_train.shape)\n", 1205 | "grid.fit(x_train, y_train) #build model\n", 1206 | "\n", 1207 | "print(grid.best_params_) #show the best parameters\n", 1208 | "\n", 1209 | "model = grid.best_estimator_ #pick up the best model\n", 1210 | "\n", 1211 | "yfit = model.predict(x_test) #use the best model\n", 1212 | "# fig, ax = plt.subplots(4, 6)\n", 1213 | "# for i, axi in enumerate(ax.flat):\n", 1214 | "# axi.imshow(x_test[i].reshape(22), cmap='bone')\n", 1215 | "# axi.set(xticks=[], yticks=[])\n", 1216 | "# axi.set_ylabel(faces.target_names[yfit[i]].split()[-1],\n", 1217 | "# color='black' if
yfit[i] == y_test[i] else 'red')\n", 1218 | "# fig.suptitle('Predicted Names; Incorrect Labels in Red', size=14)\n", 1219 | "# fig.show()\n", 1220 | "\n", 1221 | "\n", 1222 | "print(classification_report(y_test, yfit,\n", 1223 | " target_names=['edible', 'poisonous']))\n", 1224 | "f1 = f1_score(y_test, yfit, average='binary')\n", 1225 | "a = confusion_matrix(y_test, yfit)\n", 1226 | "print(f1)\n", 1227 | "print(a)" 1228 | ] 1229 | }, 1230 | { 1231 | "cell_type": "code", 1232 | "execution_count": 53, 1233 | "metadata": {}, 1234 | "outputs": [ 1235 | { 1236 | "name": "stdout", 1237 | "output_type": "stream", 1238 | "text": [ 1239 | "-----------------randomforest------------------------------\n", 1240 | "(4062, 22) (4062,)\n", 1241 | "{'n_estimators': 3}\n", 1242 | " precision recall f1-score support\n", 1243 | "\n", 1244 | " edible 1.00 1.00 1.00 2133\n", 1245 | " poisonous 1.00 1.00 1.00 1929\n", 1246 | "\n", 1247 | " accuracy 1.00 4062\n", 1248 | " macro avg 1.00 1.00 1.00 4062\n", 1249 | "weighted avg 1.00 1.00 1.00 4062\n", 1250 | "\n", 1251 | "0.9987023098883987\n", 1252 | "[[2133 0]\n", 1253 | " [ 5 1924]]\n" 1254 | ] 1255 | } 1256 | ], 1257 | "source": [ 1258 | "print('-----------------randomforest------------------------------')\n", 1259 | "\n", 1260 | "clf = RandomForestClassifier(random_state=0)\n", 1261 | "\n", 1262 | "\n", 1263 | "model1 = make_pipeline(pca, clf)  # unused below: the grid search tunes the bare forest, so feature_importances_ refer to the original 22 columns\n", 1264 | "\n", 1265 | "# param_grid = [\n", 1266 | "# {'n_estimators': [1,2,3], 'max_features': [20,21,22]}\n", 1267 | "# ]\n", 1268 | "\n", 1269 | "param_grid = [\n", 1270 | "{'n_estimators': [1,2,3]}\n", 1271 | "]\n", 1272 | "\n", 1273 | "\n", 1274 | "grid_search = GridSearchCV(clf, param_grid, cv=5,\n", 1275 | " scoring='accuracy')\n", 1276 | "\n", 1277 | "print(x_train.shape, y_train.shape)\n", 1278 | "\n", 1279 | "grid_search.fit(x_train, y_train) #build model\n", 1280 | "\n", 1281 | 
"print(grid_search.best_params_) #show the best parameters\n", 1282 | "\n", 1283 | "model1 =
grid_search.best_estimator_ #pick up the best model\n", 1284 | "\n", 1285 | "yfit = model1.predict(x_test) #use the best model\n", 1286 | "# fig, ax = plt.subplots(4, 6)\n", 1287 | "# for i, axi in enumerate(ax.flat):\n", 1288 | "# axi.imshow(x_test[i].reshape(22), cmap='bone')\n", 1289 | "# axi.set(xticks=[], yticks=[])\n", 1290 | "# axi.set_ylabel(faces.target_names[yfit[i]].split()[-1],\n", 1291 | "# color='black' if yfit[i] == y_test[i] else 'red')\n", 1292 | "# fig.suptitle('Predicted Names; Incorrect Labels in Red', size=14)\n", 1293 | "# fig.show()\n", 1294 | "\n", 1295 | "\n", 1296 | "print(classification_report(y_test, yfit,\n", 1297 | " target_names=['edible', 'poisonous']))\n", 1298 | "f1 = f1_score(y_test, yfit, average='binary')\n", 1299 | "a = confusion_matrix(y_test, yfit)\n", 1300 | "print(f1)\n", 1301 | "print(a)" 1302 | ] 1303 | }, 1304 | { 1305 | "cell_type": "code", 1306 | "execution_count": 54, 1307 | "metadata": {}, 1308 | "outputs": [ 1309 | { 1310 | "name": "stdout", 1311 | "output_type": "stream", 1312 | "text": [ 1313 | "[0.000953 0.00539851 0.06849629 0.00057451 0.27318584 0.\n", 1314 | " 0.00028443 0.11608979 0.09398737 0.09765874 0.00554243 0.01227082\n", 1315 | " 0.00332788 0.00028963 0.01112092 0. 0. 
0.\n", 1316 | " 0.21442257 0.05097382 0.03397122 0.01145224]\n" 1317 | ] 1318 | } 1319 | ], 1320 | "source": [ 1321 | "print(model1.feature_importances_)" 1322 | ] 1323 | }, 1324 | { 1325 | "cell_type": "code", 1326 | "execution_count": 55, 1327 | "metadata": {}, 1328 | "outputs": [ 1329 | { 1330 | "name": "stdout", 1331 | "output_type": "stream", 1332 | "text": [ 1333 | "gill-attachment\n", 1334 | "stalk-surface-above-ring\n", 1335 | "population\n" 1336 | ] 1337 | } 1338 | ], 1339 | "source": [ 1340 | "print(columns[5+1])\n", 1341 | "print(columns[11+1])\n", 1342 | "print(columns[20+1])" 1343 | ] 1344 | }, 1345 | { 1346 | "cell_type": "code", 1347 | "execution_count": 56, 1348 | "metadata": {}, 1349 | "outputs": [ 1350 | { 1351 | "name": "stdout", 1352 | "output_type": "stream", 1353 | "text": [ 1354 | "-----------------lr------------------------------\n", 1355 | " precision recall f1-score support\n", 1356 | "\n", 1357 | " edible 0.99 0.99 0.99 2133\n", 1358 | " poisonous 0.99 0.98 0.99 1929\n", 1359 | "\n", 1360 | " accuracy 0.99 4062\n", 1361 | " macro avg 0.99 0.99 0.99 4062\n", 1362 | "weighted avg 0.99 0.99 0.99 4062\n", 1363 | "\n", 1364 | "0.9862301896596518\n", 1365 | "[[2111 22]\n", 1366 | " [ 31 1898]]\n" 1367 | ] 1368 | } 1369 | ], 1370 | "source": [ 1371 | "print('-----------------lr------------------------------')\n", 1372 | "\n", 1373 | "lr = LogisticRegression(class_weight='balanced',penalty = 'l2',C = 1)\n", 1374 | "\n", 1375 | "\n", 1376 | "model2 = make_pipeline(pca, lr)\n", 1377 | "\n", 1378 | "# param_grid = {}\n", 1379 | "\n", 1380 | "\n", 1381 | "# grid_search = GridSearchCV(clf, param_grid, cv=5,\n", 1382 | "# scoring='neg_mean_squared_error')\n", 1383 | "\n", 1384 | "# print(x_train.shape, y_train.shape)\n", 1385 | "\n", 1386 | "model2.fit(x_train, y_train) #build model\n", 1387 | "\n", 1388 | "\n", 1389 | "\n", 1390 | "yfit = model2.predict(x_test) #use the fitted model\n", 1391 | "# fig, ax = plt.subplots(4, 6)\n", 1392 | "# for i, axi in
enumerate(ax.flat):\n", 1393 | "# axi.imshow(x_test[i].reshape(22), cmap='bone')\n", 1394 | "# axi.set(xticks=[], yticks=[])\n", 1395 | "# axi.set_ylabel(faces.target_names[yfit[i]].split()[-1],\n", 1396 | "# color='black' if yfit[i] == y_test[i] else 'red')\n", 1397 | "# fig.suptitle('Predicted Names; Incorrect Labels in Red', size=14)\n", 1398 | "# fig.show()\n", 1399 | "\n", 1400 | "\n", 1401 | "print(classification_report(y_test, yfit,\n", 1402 | " target_names=['edible', 'poisonous']))\n", 1403 | "f1 = f1_score(y_test, yfit, average='binary')\n", 1404 | "a = confusion_matrix(y_test, yfit)\n", 1405 | "print(f1)\n", 1406 | "print(a)" 1407 | ] 1408 | }, 1409 | { 1410 | "cell_type": "code", 1411 | "execution_count": 57, 1412 | "metadata": {}, 1413 | "outputs": [ 1414 | { 1415 | "name": "stdout", 1416 | "output_type": "stream", 1417 | "text": [ 1418 | "svm 10-fold correct prediction: 0.8977\n", 1419 | "randomforest 10-fold correct prediction: 0.9626\n", 1420 | "lr 10-fold correct prediction: 0.9633\n" 1421 | ] 1422 | } 1423 | ], 1424 | "source": [ 1425 | "svm_10_fold = cross_val_score(model, data_data, data_label, cv=10)\n", 1426 | "rf_10_fold = cross_val_score(model1, data_data, data_label, cv=10)\n", 1427 | "lr_10_fold = cross_val_score(model2, data_data, data_label, cv=10)\n", 1428 | "\n", 1429 | "\n", 1430 | "print('svm 10-fold correct prediction: {:4.4f}'.format(np.mean(svm_10_fold)))\n", 1431 | "print('randomforest 10-fold correct prediction: {:4.4f}'.format(np.mean(rf_10_fold)))\n", 1432 | "print('lr 10-fold correct prediction: {:4.4f}'.format(np.mean(lr_10_fold)))\n" 1433 | ] 1434 | }, 1435 | { 1436 | "cell_type": "code", 1437 | "execution_count": null, 1438 | "metadata": {}, 1439 | "outputs": [], 1440 | "source": [] 1441 | }, 1442 | { 1443 | "cell_type": "code", 1444 | "execution_count": null, 1445 | "metadata": {}, 1446 | "outputs": [], 1447 | "source": [] 1448 | } 1449 | ], 1450 | "metadata": { 1451 | "kernelspec": { 1452 | "display_name": "Python 3", 
1453 | "language": "python", 1454 | "name": "python3" 1455 | }, 1456 | "language_info": { 1457 | "codemirror_mode": { 1458 | "name": "ipython", 1459 | "version": 3 1460 | }, 1461 | "file_extension": ".py", 1462 | "mimetype": "text/x-python", 1463 | "name": "python", 1464 | "nbconvert_exporter": "python", 1465 | "pygments_lexer": "ipython3", 1466 | "version": "3.8.3" 1467 | } 1468 | }, 1469 | "nbformat": 4, 1470 | "nbformat_minor": 4 1471 | } 1472 | -------------------------------------------------------------------------------- /mashroom/read.me: -------------------------------------------------------------------------------- 1 | 2 | --------------------------------------------------------------------------------
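The random-forest cells in `cs677project.ipynb` look up feature names by index (`columns[5+1]`, `columns[11+1]`, `columns[20+1]`). The `+1` skips the `class` column, which makes off-by-one slips easy. The sketch below pairs each importance with its name directly, so no offset is needed.

It rests on some assumptions:
- The 22 features appear in the same order as the rows of the correlation output (`data.columns[1:]`).
- `FEATURE_NAMES`, `IMPORTANCES` and `top_features` are hypothetical names introduced here.
- The numbers are copied from the printed `model1.feature_importances_` output.

```python
"""Rank mushroom features by random-forest importance (illustrative sketch).

Assumption: FEATURE_NAMES follows the column order of data.columns[1:] in
cs677project.ipynb, where column 0 is the 'class' label.
"""

FEATURE_NAMES = [
    "cap-shape", "cap-surface", "cap-color", "bruises", "odor",
    "gill-attachment", "gill-spacing", "gill-size", "gill-color",
    "stalk-shape", "stalk-root", "stalk-surface-above-ring",
    "stalk-surface-below-ring", "stalk-color-above-ring",
    "stalk-color-below-ring", "veil-type", "veil-color", "ring-number",
    "ring-type", "spore-print-color", "population", "habitat",
]

# Values copied from the notebook's print(model1.feature_importances_) output.
IMPORTANCES = [
    0.000953, 0.00539851, 0.06849629, 0.00057451, 0.27318584, 0.0,
    0.00028443, 0.11608979, 0.09398737, 0.09765874, 0.00554243, 0.01227082,
    0.00332788, 0.00028963, 0.01112092, 0.0, 0.0, 0.0,
    0.21442257, 0.05097382, 0.03397122, 0.01145224,
]


def top_features(importances, names, k=3):
    """Return the k (name, importance) pairs with the largest importance."""
    if len(importances) != len(names):
        raise ValueError("importances and names must have the same length")
    # zip keeps each score attached to its own column name,
    # so no manual columns[i + 1] offset is needed.
    pairs = list(zip(names, importances))
    pairs.sort(key=lambda p: p[1], reverse=True)
    return pairs[:k]


if __name__ == "__main__":
    # Prints odor, ring-type and gill-size, in that order.
    for name, score in top_features(IMPORTANCES, FEATURE_NAMES):
        print(f"{name:<25} {score:.4f}")
```

With these values the three most important features are odor, ring-type and gill-size. Inside the notebook, the same pairing could be written as `sorted(zip(columns[1:], model1.feature_importances_), key=lambda p: p[1], reverse=True)[:3]`. That assumes `columns` is the same list used to build `data_data`.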