├── conf_test.png
├── README.md
└── main.py
/conf_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PB2204/Water-Quality/HEAD/conf_test.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Water Quality Analysis
2 |
3 | # Dataset Information
4 |
5 | Access to safe drinking water is essential to health and is a basic human right. It matters as a health and development issue at the national, regional, and local levels. In some regions, investing in water supply and sanitation can yield a net economic benefit, because the reductions in adverse health effects and health care costs outweigh the costs of the interventions.
6 |
7 | ### Attribute Information:
8 |
9 | Input variables (based on physicochemical tests): \
10 | 1 - ph-> pH of water \
11 | 2 - Hardness-> Capacity of water to precipitate soap in mg/L \
12 | 3 - Solids-> Total dissolved solids in ppm \
13 | 4 - Chloramines-> Amount of Chloramines in ppm \
14 | 5 - Sulfate-> Amount of Sulfates dissolved in mg/L \
15 | 6 - Conductivity-> Electrical conductivity of water in μS/cm \
16 | 7 - Organic_carbon-> Amount of organic carbon in ppm \
17 | 8 - Trihalomethanes-> Amount of Trihalomethanes in μg/L \
18 | 9 - Turbidity-> Measure of the light-scattering property (cloudiness) of water in NTU (Nephelometric Turbidity Units) \
19 |
20 | Output variable (based on sensory data): \
21 | 10 - Potability-> Indicates if water is safe for human consumption
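
The attribute list above can double as a quick sanity check on the downloaded file. This is a minimal sketch, assuming the CSV header uses exactly the names listed above; `EXPECTED_COLUMNS` and `missing_columns` are hypothetical names introduced here, and the inline sample is made up to stand in for the real file.

```python
import csv
import io

# Column names as listed in the attribute section (assumed to match the CSV header).
EXPECTED_COLUMNS = [
    "ph", "Hardness", "Solids", "Chloramines", "Sulfate",
    "Conductivity", "Organic_carbon", "Trihalomethanes",
    "Turbidity", "Potability",
]


def missing_columns(csv_file):
    """Return the expected columns that are absent from the CSV header row."""
    header = next(csv.reader(csv_file))
    return [name for name in EXPECTED_COLUMNS if name not in header]


# Tiny made-up sample standing in for water_potability.csv.
sample = io.StringIO("ph,Hardness,Solids,Potability\n7.0,200.0,20000.0,0\n")
print(missing_columns(sample))
```

With the real dataset, the same function could be called on `open("water_potability.csv", newline="")`; an empty list would mean every expected column is present.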
22 |
23 |
24 | **Download link:** https://www.kaggle.com/adityakadiwal/water-potability
25 |
26 |
27 | # Libraries
28 |
29 |
30 | pandas \
31 | matplotlib \
32 | seaborn \
33 | plotly \
34 | scikit-learn \
35 | xgboost
36 |
37 | # Algorithms
38 | Logistic Regression \
39 | K Nearest Neighbours \
40 | Support Vector Machine \
41 | Decision Tree \
42 | Random Forest \
43 | XGBoost
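
As a rough illustration of how several of these classifiers can be compared on one train/test split, here is a sketch under stated assumptions: it uses synthetic data from `make_classification` rather than the water dataset, covers only the scikit-learn models (XGBoost is left out to keep the dependencies small), and the hyperparameters are defaults, not the project's settings.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in with nine features (NOT the real water-quality data).
X, y = make_classification(n_samples=400, n_features=9, n_informative=5, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.85, random_state=42)

# Fit the scaler on the training split only, then apply it to both splits.
scaler = StandardScaler().fit(X_train)
X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)

models = {
    "Logistic Regression": LogisticRegression(random_state=0),
    "K Nearest Neighbours": KNeighborsClassifier(),
    "Support Vector Machine": SVC(),
    "Decision Tree": DecisionTreeClassifier(random_state=0),
    "Random Forest": RandomForestClassifier(random_state=0),
}

# Test-set accuracy for each model, as a fraction between 0 and 1.
scores = {name: model.fit(X_train, y_train).score(X_test, y_test)
          for name, model in models.items()}

for name, acc in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name}: {acc:.3f}")
```

The numbers printed here say nothing about the water dataset; the point is only the shape of the comparison loop.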
44 |
45 | **Best Model Accuracy:** 73.97658631%
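
For context on how a figure like this is usually computed, here is a small sketch, assuming the number above is a test-set accuracy expressed as a percentage. `accuracy_percent` and `majority_baseline_percent` are hypothetical helpers and the labels are made up; comparing against the majority-class baseline helps show how much a model adds over always predicting the most common label.

```python
def accuracy_percent(y_true, y_pred):
    """Share of predictions that match the true labels, as a percentage."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return 100.0 * correct / len(y_true)


def majority_baseline_percent(y_true):
    """Accuracy obtained by always predicting the most frequent label."""
    labels = list(y_true)
    most_common = max(set(labels), key=labels.count)
    return accuracy_percent(labels, [most_common] * len(labels))


# Made-up labels: 1 = potable, 0 = not potable.
y_true = [0, 0, 0, 1, 1, 0, 1, 0]
y_pred = [0, 0, 1, 1, 0, 0, 1, 0]

print(accuracy_percent(y_true, y_pred))   # 6 of 8 correct -> 75.0
print(majority_baseline_percent(y_true))  # 5 of 8 are 0   -> 62.5
```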
46 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # pip install scikit-learn seaborn plotly matplotlib numpy pandas xgboost tqdm
2 |
3 | import seaborn as sns
4 | import plotly.express as px
5 | import matplotlib.pyplot as plt
6 | import plotly.graph_objects as go
7 | from tqdm import tqdm
8 | import plotly.figure_factory as ff
9 | import numpy as np
10 | import pandas as pd
11 |
12 | import warnings
13 | warnings.filterwarnings('ignore')
14 |
15 | plt.style.use('fivethirtyeight')
16 | # %matplotlib inline
17 | import os
18 | for dirname, _, filenames in os.walk('cc'):
19 |     for filename in filenames:
20 |         print(os.path.join(dirname, filename))
21 |
22 | # Reading The Data-Set
23 | data=pd.read_csv('water_potability.csv')  # assumes the CSV is in the working directory
24 | data.head()
25 |
26 | # EDA
27 | '''
28 |
29 | * ph-> pH of water
30 | * Hardness-> Capacity of water to precipitate soap in mg/L
31 | * Solids-> Total dissolved solids in ppm
32 | * Chloramines-> Amount of Chloramines in ppm
33 | * Sulfate-> Amount of Sulfates dissolved in mg/L
34 | * Conductivity-> Electrical conductivity of water in μS/cm
35 | * Organic_carbon-> Amount of organic carbon in ppm
36 | * Trihalomethanes-> Amount of Trihalomethanes in μg/L
37 | * Turbidity-> Measure of the light-scattering property (cloudiness) of water in NTU (Nephelometric Turbidity Units)
38 | * Potability-> Indicates if water is safe for human consumption
39 |
40 | '''
41 |
42 | # Describe The Data
43 | data.describe()
44 |
45 | # Information Of The Data
46 | data.info()
47 |
48 | print('There are {} data points and {} features in the data.'.format(data.shape[0],data.shape[1]))
49 |
50 |
51 | # Null Values
52 | sns.heatmap(data.isnull(),yticklabels=False,cbar=False,cmap='viridis')
53 |
54 | for i in data.columns:
55 |     if data[i].isnull().sum()>0:
56 |         print("There are {} null values in {} column.".format(data[i].isnull().sum(),i))
57 |
58 | '''
59 | # Handling Null Values
60 | '''
61 |
62 | # ph
63 | data['ph'].describe()
64 |
65 | # Filling The Missing Values By Mean
66 | data['ph_mean']=data['ph'].fillna(data['ph'].mean())
67 |
68 | data['ph_mean'].isnull().sum()
69 |
70 | # Graphical Plotting
71 | fig = plt.figure()
72 | ax = fig.add_subplot(111)
73 | data['ph'].plot(kind='kde', ax=ax)
74 | data.ph_mean.plot(kind='kde', ax=ax, color='red')
75 | lines, labels = ax.get_legend_handles_labels()
76 | ax.legend(lines, labels, loc='best')
77 | plt.show()
78 |
79 | '''
80 | #### The Distribution Is Not Uniform
81 | '''
82 | # Filling The Data With Random Values
83 | def impute_nan(df,variable):
84 |     df[variable+"_random"]=df[variable]
85 |     # Draw one observed (non-null) value at random for every missing entry
86 |     random_sample=df[variable].dropna().sample(df[variable].isnull().sum(),random_state=0)
87 |     # Give the sample the index of the missing rows so pandas aligns the assignment
88 |     random_sample.index=df[df[variable].isnull()].index
89 |     df.loc[df[variable].isnull(),variable+'_random']=random_sample
90 |
91 | # Random-Sample Imputation (Fills Missing Values With Randomly Drawn Observed Values)
92 | impute_nan(data,"ph")
93 |
94 | # ph_random & ph_mean Graph Plotting
95 | fig = plt.figure()
96 | ax = fig.add_subplot(111)
97 | data['ph'].plot(kind='kde', ax=ax)
98 | data.ph_random.plot(kind='kde', ax=ax, color='green')
99 | data.ph_mean.plot(kind='kde', ax=ax, color='red')
100 | lines, labels = ax.get_legend_handles_labels()
101 | ax.legend(lines, labels, loc='best')
102 | plt.show()
103 |
104 | # Random-Sample Imputation (Fills Missing Values With Randomly Drawn Observed Values)
105 | impute_nan(data,"Sulfate")
106 |
107 | # Sulfate_random Graphical Plotting
108 | fig = plt.figure()
109 | ax = fig.add_subplot(111)
110 | data['Sulfate'].plot(kind='kde', ax=ax)
111 | data["Sulfate_random"].plot(kind='kde', ax=ax, color='green')
112 | lines, labels = ax.get_legend_handles_labels()
113 | ax.legend(lines, labels, loc='best')
114 | plt.show()
115 |
116 | # Random-Sample Imputation (Fills Missing Values With Randomly Drawn Observed Values)
117 | impute_nan(data,"Trihalomethanes")
118 |
119 | # Trihalomethanes Graphical Plotting
120 | fig = plt.figure()
121 | ax = fig.add_subplot(111)
122 | data['Trihalomethanes'].plot(kind='kde', ax=ax)
123 | data.Trihalomethanes_random.plot(kind='kde', ax=ax, color='green')
124 | lines, labels = ax.get_legend_handles_labels()
125 | ax.legend(lines, labels, loc='best')
126 | plt.show()
127 |
128 |
129 |
130 | data=data.drop(['ph','Sulfate','Trihalomethanes','ph_mean'],axis=1)
131 | data.isnull().sum()
132 |
133 | '''
134 | ## Check For Correlation
135 | '''
136 | # Graphical Representation
137 | plt.figure(figsize=(20, 17))
138 | matrix = np.triu(data.corr())
139 | sns.heatmap(data.corr(), annot=True,linewidths=.8, mask=matrix, cmap="rocket",cbar=False);
140 |
141 |
142 | '''
143 | # There Are No Correlated Columns Present In The Data
144 | '''
145 | # Graphical Representation
146 | sns.pairplot(data, hue="Potability", palette="husl");
147 |
148 |
149 |
150 | # Graphical Representation
151 | from tqdm import tqdm
152 |
153 | def distributionPlot(data):
154 |     fig = plt.figure(figsize=(20, 20))
155 |     num_columns = len(data.columns)
156 |     num_rows = int(np.ceil(num_columns / 3))
157 |
158 |     for i in tqdm(range(num_columns)):
159 |         fig.add_subplot(num_rows, 3, i + 1)
160 |         sns.distplot(data.iloc[:, i], color="lightcoral", rug=True)
161 |
162 |     fig.tight_layout(pad=3)
163 |
164 | plot_data = data.drop(['Potability'], axis=1)
165 | distributionPlot(plot_data)
166 |
167 |
168 | # Hardness
169 | data['Hardness'].describe()
170 |
171 | # Distribution Plot Of Hardness Graph
172 | plt.figure(figsize = (16, 7))
173 | sns.distplot(data['Hardness'])
174 | plt.title('Distribution Plot Of Hardness\n', fontsize = 20)
175 | plt.show()
176 |
177 | # Hardness WRT Potability Graph
178 | # basic scatter plot
179 | fig = px.scatter(data,range(data['Hardness'].count()), sorted(data['Hardness']),
180 | color=data['Potability'],
181 | labels={
182 | 'x': "Count",
183 | 'y': "Hardness",
184 | 'color':'Potability'
185 |
186 | }, template = 'plotly_dark')
187 | fig.update_layout(title='Hardness WRT Potability')
188 | fig.show()
189 |
190 | # Plotly Dark Graph
191 | px.histogram(data_frame = data, x = 'Hardness', nbins = 10, color = 'Potability', marginal = 'box',
192 | template = 'plotly_dark')
193 |
194 | '''
195 | # Solids
196 | '''
197 | data['Solids'].describe()
198 |
199 | # Distribution Plot Of Solids Graph
200 | plt.figure(figsize = (16, 7))
201 | sns.distplot(data['Solids'])
202 | plt.title('Distribution Plot Of Solids\n', fontsize = 20)
203 | plt.show()
204 |
205 | # Potability Graph
206 | fig = px.scatter(data, sorted(data["Solids"]), range(data["Solids"].count()), color="Potability", facet_col="Potability",
207 | facet_row="Potability")
208 | fig.show()
209 |
210 | # Potability Plotly Dark Graph
211 | px.histogram(data_frame = data, x = 'Solids', nbins = 10, color = 'Potability', marginal = 'box',
212 | template = 'plotly_dark')
213 |
214 | # Solids WRT Potability Graph
215 | # basic scatter plot
216 | fig = px.scatter(data,range(data['Solids'].count()), sorted(data['Solids']),
217 | color=data['Potability'],
218 | labels={
219 | 'x': "Count",
220 | 'y': "Solids",
221 | 'color':'Potability'
222 |
223 | },
224 | color_continuous_scale=px.colors.sequential.tempo,
225 | template = 'plotly_dark')
226 | fig.update_layout(title='Solids WRT Potability')
227 | fig.show()
228 |
229 |
230 | '''
231 | # Chloramines
232 | '''
233 | data['Chloramines'].describe()
234 |
235 | # Distribution Plot Of Chloramines Graph
236 | plt.figure(figsize = (16, 7))
237 | sns.distplot(data['Chloramines'])
238 | plt.title('Distribution Plot Of Chloramines\n', fontsize = 20)
239 | plt.show()
240 |
241 | # Chloramines WRT Potability Graph
242 | fig = px.line(x=range(data['Chloramines'].count()), y=sorted(data['Chloramines']),color=data['Potability'], labels={
243 | 'x': "Count",
244 | 'y': "Chloramines",
245 | 'color':'Potability'
246 |
247 | }, template = 'plotly_dark')
248 | fig.update_layout(title='Chloramines WRT Potability')
249 | fig.show()
250 |
251 | # Chloramines Graph
252 | fig = px.box(x = 'Chloramines', data_frame = data, template = 'plotly_dark')
253 | fig.update_layout(title='Chloramines')
254 | fig.show()
255 |
256 |
257 | '''
258 | # # Conductivity
259 | '''
260 | data["Conductivity"].describe()
261 |
262 | # Distribution Plot Of Conductivity Graph
263 | plt.figure(figsize = (16, 7))
264 | sns.distplot(data['Conductivity'])
265 | plt.title('Distribution Plot Of Conductivity\n', fontsize = 20)
266 | plt.show()
267 |
268 | # Conductivity WRT Potability Graph
269 | fig = px.bar(data, x=range(data['Conductivity'].count()),
270 | y=sorted(data['Conductivity']), labels={
271 | 'x': "Count",
272 | 'y': "Conductivity",
273 | 'color':'Potability'
274 |
275 | },
276 | color=data['Potability']
277 | ,template = 'plotly_dark')
278 | fig.update_layout(title='Conductivity WRT Potability')
279 | fig.show()
280 |
281 | # Conductivity Graph
282 | group_labels = ['distplot'] # name of the dataset
283 |
284 | fig = ff.create_distplot([data['Conductivity']], group_labels)
285 | fig.show()
286 |
287 |
288 |
289 | '''
290 | # Organic_carbon
291 | '''
292 | data['Organic_carbon'].describe()
293 |
294 | # Organic_carbon Graph
295 | group_labels = ['Organic_carbon'] # name of the dataset
296 |
297 | fig = ff.create_distplot([data['Organic_carbon']], group_labels)
298 | fig.show()
299 |
300 |
301 | # Number Of Samples Per Organic_carbon Range Graph
302 | dt_5=data[data['Organic_carbon']<5]
303 | dt_5_10=data[(data['Organic_carbon']>5)&(data['Organic_carbon']<10)]
304 | dt_10_15=data[(data['Organic_carbon']>10)&(data['Organic_carbon']<15)]
305 | dt_15_20=data[(data['Organic_carbon']>15)&(data['Organic_carbon']<20)]
306 | dt_20_25=data[(data['Organic_carbon']>20)&(data['Organic_carbon']<25)]
307 | dt_25=data[(data['Organic_carbon']>25)]
308 |
309 | x_bins = ['<5', '5-10', '10-15', '15-20', '20-25', '25+']
310 | y_counts = [len(dt_5), len(dt_5_10), len(dt_10_15), len(dt_15_20),
311 | len(dt_20_25), len(dt_25)]
312 |
313 | px.bar(x = x_bins, y = y_counts, color = x_bins, template = 'plotly_dark',
314 | title = 'Number Of Samples Per Organic_carbon Range')
315 |
316 |
317 | # Organic_carbon By Potability Box Plot
318 | sns.catplot(x = 'Potability', y = 'Organic_carbon', data = data, kind = 'box',
319 | height = 5, aspect = 2)
320 | plt.show()
321 |
322 |
323 |
324 | '''
325 | # Turbidity
326 | '''
327 | data['Turbidity'].describe()
328 |
329 | # Turbidity Graph
330 | group_labels = ['Turbidity'] # name of the dataset
331 |
332 | fig = ff.create_distplot([data['Turbidity']], group_labels)
333 | fig.show()
334 |
335 | data['turbid_class']=data['Turbidity'].astype(int)
336 | data['turbid_class'].unique()
337 |
338 | # Turbidity turbidity_class Graph
339 | px.scatter(data_frame = data, x = 'Turbidity', y = 'turbid_class', color = 'Potability', template = 'plotly_dark')
340 |
341 | data=data.drop(['turbid_class'],axis=1)
342 |
343 |
344 | '''
345 | # ph_random
346 | '''
347 | data['ph_random'].describe()
348 |
349 | # ph_random Graph
350 | group_labels = ['ph_random'] # name of the dataset
351 |
352 | fig = ff.create_distplot([data['ph_random']], group_labels)
353 | fig.show()
354 |
355 | # ph_random & Potability Graph
356 | px.histogram(data_frame = data, x = 'ph_random', nbins = 10, color = 'Potability', marginal = 'box',
357 | template = 'plotly_dark')
358 |
359 |
360 | # --------------------------------
361 | fig = px.scatter(data, sorted(data["ph_random"]), range(data["ph_random"].count()), color="Potability", facet_col="Potability",
362 | facet_row="Potability")
363 | fig.show()
364 |
365 |
366 | '''
367 | # Sulfate_random
368 | '''
369 | data['Sulfate_random'].describe()
370 |
371 | # Sulfate_random Graph
372 | group_labels = ['distplot'] # name of the dataset
373 |
374 | fig = ff.create_distplot([data['Sulfate_random']], group_labels)
375 | fig.show()
376 |
377 | # Sulfate_random By Potability Box Plot
378 | sns.catplot(x = 'Potability', y = 'Sulfate_random', data = data, kind = 'box',
379 | height = 5, aspect = 2)
380 | plt.show()
381 |
382 |
383 |
384 | '''
385 | # Trihalomethanes_random
386 | '''
387 | data['Trihalomethanes_random'].describe()
388 |
389 | # Trihalomethanes_random Graph
390 | group_labels = ['Trihalomethanes_random'] # name of the dataset
391 |
392 | # Trihalomethanes_random Plotly Graph
393 | fig = ff.create_distplot([data['Trihalomethanes_random']], group_labels)
394 | fig.show()
395 |
396 | fig = px.box(x = 'Trihalomethanes_random', data_frame = data, template = 'plotly_dark')
397 | fig.update_layout(title='Trihalomethanes_random')
398 | fig.show()
399 |
400 | # Trihalomethane wrt Potability Graph
401 | fig = px.line(x=range(data['Trihalomethanes_random'].count()), y=sorted(data['Trihalomethanes_random']),color=data['Potability'], labels={
402 | 'x': "Count",
403 | 'y': "Trihalomethanes",
404 | 'color':'Potability'
405 |
406 | }, template = 'plotly_dark')
407 | fig.update_layout(title='Trihalomethane wrt Potability')
408 | fig.show()
409 |
410 |
411 |
412 | '''
413 | # Potability
414 | '''
415 | data['Potability'].describe()
416 |
417 | # Potability Plotly Dark Graph
418 | px.histogram(data_frame = data, x = 'Potability', color = 'Potability', marginal = 'box',
419 | template = 'plotly_dark')
420 |
421 | """
422 | # Data Preprocessing
423 | """
424 | from sklearn.preprocessing import StandardScaler
425 | from sklearn.model_selection import train_test_split
426 |
427 | X=data.drop(['Potability'],axis=1)
428 | y=data['Potability']
429 |
430 | # Since The Features Are On Very Different Scales, We Standardize Them Using StandardScaler
431 | scaler = StandardScaler()
432 | x=scaler.fit_transform(X)
433 |
434 | # split the data to train and test set
435 | x_train,x_test,y_train,y_test = train_test_split(x,y,train_size=0.85,random_state=42)
436 |
437 |
438 | print("Training data shape:-{} labels{} ".format(x_train.shape,y_train.shape))
439 | print("Testing data shape:-{} labels{} ".format(x_test.shape,y_test.shape))
440 |
441 |
442 |
443 |
444 | """
445 | # Modeling
446 | """
447 | # ### Logistic Regression
448 | from sklearn.linear_model import LogisticRegression
449 | log = LogisticRegression(random_state=0).fit(x_train, y_train)
450 | log.score(x_test, y_test)
451 |
452 | # Confusion Matrix testing data Graph
453 | # Confusion matrix
454 | from sklearn.metrics import confusion_matrix
455 | # Make Predictions
456 | pred1=log.predict(np.array(x_test))
457 | plt.title("Confusion Matrix testing data")
458 | sns.heatmap(confusion_matrix(y_test,pred1),annot=True,cbar=False)
459 | plt.legend()
460 | plt.show()
461 |
462 |
463 |
464 | # ### K Nearest Neighbours
465 | from sklearn.neighbors import KNeighborsClassifier
466 |
467 | knn = KNeighborsClassifier(n_neighbors=2)
468 | # Train the model using the training sets
469 | knn.fit(x_train,y_train)
470 |
471 | #Predict Output
472 | predicted= knn.predict(x_test)
473 |
474 | # Confusion Matrix testing data Graph
475 | # Confusion matrix
476 | from sklearn.metrics import confusion_matrix
477 | # Make Predictions
478 | pred1=knn.predict(np.array(x_test))
479 | plt.title("Confusion Matrix testing data")
480 | sns.heatmap(confusion_matrix(y_test,pred1),annot=True,cbar=False)
481 | plt.legend()
482 | plt.show()
483 |
484 |
485 |
486 | # ### SVM
487 | from sklearn import svm
488 | from sklearn.metrics import accuracy_score
489 |
490 | svmc = svm.SVC()
491 | svmc.fit(x_train, y_train)
492 |
493 | y_pred = svmc.predict(x_test)
494 | print(accuracy_score(y_test,y_pred))
495 |
496 | # Confusion Matrix testing data Graph
497 | # Confusion matrix
498 | from sklearn.metrics import confusion_matrix
499 | # Make Predictions
500 | pred1=svmc.predict(np.array(x_test))
501 | plt.title("Confusion Matrix testing data")
502 | sns.heatmap(confusion_matrix(y_test,pred1),annot=True,cbar=False)
503 | plt.legend()
504 | plt.show()
505 |
506 |
507 |
508 | # ### Decision Tree
509 | from sklearn import tree
510 | from sklearn.metrics import accuracy_score
511 |
512 | tre = tree.DecisionTreeClassifier()
513 | tre = tre.fit(x_train, y_train)
514 |
515 | y_pred = tre.predict(x_test)
516 | print(accuracy_score(y_test,y_pred))
517 |
518 | # Confusion Matrix testing data Graph
519 | # Confusion matrix
520 | from sklearn.metrics import confusion_matrix
521 | # Make Predictions
522 | pred1=tre.predict(np.array(x_test))
523 | plt.title("Confusion Matrix testing data")
524 | sns.heatmap(confusion_matrix(y_test,pred1),annot=True,cbar=False)
525 | plt.legend()
526 | plt.show()
527 |
528 |
529 |
530 | # ### Random Forest
531 | from sklearn.ensemble import RandomForestClassifier
532 | from sklearn.metrics import accuracy_score
533 |
534 | # create the model
535 | model_rf = RandomForestClassifier(n_estimators=500, oob_score=True, random_state=100)
536 |
537 | # fitting the model
538 | model_rf=model_rf.fit(x_train, y_train)
539 |
540 | y_pred = model_rf.predict(x_test)
541 | print(accuracy_score(y_test,y_pred))
542 |
543 | # Confusion Matrix testing data Graph
544 | # Confusion matrix
545 | from sklearn.metrics import confusion_matrix
546 | # Make Predictions
547 | pred1=model_rf.predict(np.array(x_test))
548 | plt.title("Confusion Matrix testing data")
549 | sns.heatmap(confusion_matrix(y_test,pred1),annot=True,cbar=False)
550 | plt.legend()
551 | plt.show()
552 |
553 |
554 |
555 | # ### XG Boost
556 | from xgboost import XGBClassifier
557 | from sklearn.metrics import r2_score
558 |
559 | xgb = XGBClassifier(colsample_bylevel= 0.9,
560 | colsample_bytree = 0.8,
561 | gamma=0.99,
562 | max_depth= 5,
563 | min_child_weight= 1,
564 | n_estimators= 8,
565 | nthread= 5,
566 | random_state= 0,
567 | )
568 | xgb.fit(x_train,y_train)
569 |
570 | print('Accuracy Of XGBoost Classifier On Training Set: {:.2f}'
571 | .format(xgb.score(x_train, y_train)))
572 | print('Accuracy Of XGBoost Classifier On Test Set: {:.2f}'
573 | .format(xgb.score(x_test, y_test)))
574 |
575 | # Test Confusion Matrix Graph
576 | from sklearn.metrics import confusion_matrix
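577 | y_pred = xgb.predict(x_test)  # use XGBoost predictions; y_pred otherwise still holds the Random Forest output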
578 | conf_matrix = confusion_matrix(y_true=y_test, y_pred=y_pred)
579 | plt.figure(figsize = (15, 8))
580 | sns.set(font_scale=1.4) # for label size
581 | sns.heatmap(conf_matrix, annot=True, annot_kws={"size": 16},cbar=False, linewidths = 1) # font size
582 | plt.title("Test Confusion Matrix")
583 | plt.xlabel("Predicted class")
584 | plt.ylabel("Actual class")
585 | plt.savefig('conf_test.png')
586 | plt.show()
587 |
588 |
589 |
590 | # ### SVM Tuned
591 | from sklearn.svm import SVC
592 | from sklearn.model_selection import GridSearchCV
593 | svc=SVC()
594 | param_grid={'C':[1.2,1.5,2.2,3.5,3.2,4.1],'kernel':['linear', 'poly', 'rbf', 'sigmoid'],'degree':[1,2,4,8,10],'gamma':['scale','auto']}
595 | gridsearch=GridSearchCV(svc,param_grid=param_grid,n_jobs=-1,verbose=4,cv=3)
596 | gridsearch.fit(x_train,y_train)
597 |
598 | # Test Confusion Matrix Graph
599 | y_pred=gridsearch.predict(x_test)
600 | from sklearn.metrics import confusion_matrix
601 |
602 | conf_matrix = confusion_matrix(y_true=y_test, y_pred=y_pred)
603 | plt.figure(figsize = (15, 8))
604 | sns.set(font_scale=1.4) # for label size
605 | sns.heatmap(conf_matrix, annot=True, annot_kws={"size": 16},cbar=False, linewidths = 1) # font size
606 | plt.title("Test Confusion Matrix")
607 | plt.xlabel("Predicted class")
608 | plt.ylabel("Actual class")
609 | plt.savefig('conf_test.png')
610 | plt.show()
--------------------------------------------------------------------------------