├── LICENSE ├── README.md ├── app └── app.py ├── data └── train_v1.csv ├── model_comparison.md ├── models ├── label_encoders.pkl ├── model_info.pkl ├── scaler.pkl ├── stroke_model.pkl └── train_model.py ├── project_documentation.md ├── requirements.txt ├── static ├── S.png ├── confusion_matrix.png ├── feature_importance.png └── feature_names.txt └── test_cases.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Abdullah Fayed , Azza Sadek , Mayar Tamer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stroke Risk Prediction System 2 | 3 | A machine learning-based web application that predicts the risk of stroke in patients using various health parameters. The application provides a user-friendly interface in both English and Arabic languages. 4 | 5 | ## Features 6 | 7 | - 🧠 Advanced XGBoost-based stroke risk prediction 8 | - 🌐 Bilingual support (English/Arabic) 9 | - 🎨 Dark/Light theme options 10 | - 📊 Detailed risk factor analysis 11 | - 💻 User-friendly web interface 12 | - 📱 Responsive design 13 | - 🔄 Built-in handling of missing values 14 | - 📈 Automatic feature importance calculation 15 | 16 | ## Project Structure 17 | 18 | ``` 19 | stroke-prediction/ 20 | ├── app/ 21 | │ └── app.py # Streamlit application 22 | ├── models/ # Trained models and artifacts 23 | ├── static/ # Static assets 24 | ├── data/ # Dataset 25 | ├── docs/ # Documentation 26 | ├── requirements.txt # Python dependencies 27 | ├── README.md # This file 28 | ├── project_documentation.md # Workflow and design details 29 | └── model_comparison.md # ML models comparison 30 | ``` 31 | 32 | ## Usage 33 | 1. Install dependencies: 34 | ```bash 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | 2. Start the Streamlit application: 39 | ```bash 40 | streamlit run app/app.py 41 | ``` 42 | 43 | 3. Open your web browser and navigate to: 44 | ``` 45 | http://localhost:8501 46 | ``` 47 | 48 | ## Model Information 49 | 50 | The system uses an XGBoost Classifier with the following performance metrics: 51 | - Accuracy: 94.8% 52 | - Recall: 92.5% 53 | - Precision: 94.2% 54 | - F1-Score: 93.3% 55 | - ROC-AUC Score: 0.97 56 | 57 | Key Features: 58 | - Optimized for medical risk assessment 59 | - Excellent handling of imbalanced data 60 | - Built-in feature importance analysis 61 | - Robust to missing values 62 | 63 | For detailed model comparison and technical information, see [Model Comparison](model_comparison.md). 64 | 65 | ## License 66 | 67 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 68 | -------------------------------------------------------------------------------- /app/app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import joblib 3 | import pandas as pd 4 | import numpy as np 5 | from sklearn.preprocessing import StandardScaler 6 | from pathlib import Path 7 | 8 | # Load model and artifacts at startup 9 | model_path = Path("../models/stroke_model.pkl") 10 | model_info_path = Path("../models/model_info.pkl") 11 | label_encoders_path = Path("../models/label_encoders.pkl") 12 | scaler_path = Path("../models/scaler.pkl") 13 | 14 | try: 15 | model = joblib.load(model_path) 16 | model_info = joblib.load(model_info_path) 17 | label_encoders = joblib.load(label_encoders_path) 18 | scaler = joblib.load(scaler_path) 19 | print("Successfully loaded all model artifacts") 20 | except Exception as e: 21 | st.error(f"Error loading model artifacts: {str(e)}") 22 | st.stop() 23 | 24 | print("Loaded model features:", model.feature_names_in_) 25 | 26 | language = st.sidebar.radio("Select Language | اختر اللغة", ["English", "العربية"]) 27 | translations = { 28 | 29 | "English": { 30 | "title": "Welcome to our System", 31 | "input_prompt": "Enter Patient Details to Predict Stroke Risk", 32 | "age": "Age", 33 | "hypertension": "Hypertension (0 = No, 1 = Yes)", 34 | "heart_disease": "Heart Disease (0 = No, 1 = Yes)", 35 | "bmi": "BMI", 36 | "glucose": "Average Glucose Level", 37 | "smoking_status": "Smoking Status", 38 | "predict_button": "Predict Stroke Risk", 39 | "about_title": "About This App", 40 | "about_description": "This application predicts the risk of stroke based on patient data...", 41 | "project_overview": "Project Overview", 42 | "how_it_works": "How It Works", 43 | "theme_label": "Select Theme", 44 | "home": "Home", 45 | "about": "About", 46 | "gender": "Gender", 47 | "male": "Male", 48 | "female": "Female", 49 | "age": "Age", 50 | "hypertension": "Hypertension", 51 | "heart_disease": "Heart Disease", 52 | "ever_married": "Ever Married", 53 | "no": "No", 54 | "yes": "Yes", 55 | "work_type": "Work Type", 56 | "private": "Private", 57 | "self_employed": "Self-employed", 58 | "govt_job": "Govt_job", 59 | "children": "children", 60 | "residence_type": "Residence Type", 61 | "urban": "Urban", 62 | "rural": "Rural", 63 | "avg_glucose_level": "Average Glucose Level", 64 | "weight": "Weight (kg)", 65 | "height": "Height (cm)", 66 | "bmi": "BMI", 67 | "smoking_status": "Smoking Status", 68 | "never_smoked": "never smoked", 69 | "formerly_smoked": "formerly smoked", 70 | "smokes": "smokes", 71 | "unknown": "Unknown", 72 | "why_it_matters": "Why It Matters", 73 | "select_theme": "Select theme", 74 | "predict_stroke_risk": "Predict Stroke Risk", 75 | "light_mode": "Light Mode", 76 | "dark_mode": "Dark Mode" 77 | }, 78 | 79 | 80 | "العربية": { 81 | "title": "مرحبًا بك في نظامنا", 82 | "input_prompt": "أدخل تفاصيل المريض للتنبؤ بمخاطر السكتة الدماغية", 83 | "age": "العمر", 84 | "hypertension": "ارتفاع ضغط الدم (0 = لا, 1 = نعم)", 85 | "heart_disease": "أمراض القلب (0 = لا, 1 = نعم)", 86 | "bmi": "مؤشر كتلة الجسم", 87 | "glucose": "متوسط مستوى الجلوكوز", 88 | "smoking_status": "حالة التدخين", 89 | "predict_button": "توقع خطر السكتة الدماغية", 90 | "about_title": "عن هذا التطبيق", 91 | "about_description": "يقوم هذا التطبيق بتوقع خطر الإصابة بالسكتة الدماغية بناءً على بيانات المريض...", 92 | "project_overview": "نظرة عامة على المشروع", 93 | "how_it_works": "كيف يعمل", 94 | "theme_label": "اختر الوضع", 95 | "home": "الرئيسية", 96 | "about": "عن التطبيق", 97 | "gender": "الجنس", 98 | "male": "ذكر", 99 | "female": "أنثى", 100 | "age": "العمر", 101 | "hypertension": "ارتفاع ضغط الدم", 102 | "heart_disease": "أمراض القلب", 103 | "ever_married": "متزوج مسبقًا", 104 | "no": "لا", 105 | "yes": "نعم", 106 | "work_type": "نوع العمل", 107 | "private": "قطاع خاص", 108 | "self_employed": "عمل حر", 109 | "govt_job": "وظيفة حكومية", 110 | "children": "طالب/طفل", 111 | "residence_type": "نوع السكن", 112 | "urban": "مدني", 113 | "rural": "ريفي", 114 | "avg_glucose_level": "متوسط مستوى الجلوكوز", 115 | "weight": "الوزن (كجم)", 116 | "height": "الطول (سم)", 117 | "bmi": "مؤشر كتلة الجسم", 118 | "smoking_status": "حالة التدخين", 119 | "never_smoked": "لم يدخن أبدًا", 120 | "formerly_smoked": "كان مدخنًا سابقًا", 121 | "smokes": "يدخن", 122 | "unknown": "غير معروف", 123 | "why_it_matters": "لماذا هذا مهم", 124 | "select_theme": "اختر النمط", 125 | "predict_stroke_risk": "توقع خطر السكتة الدماغية", 126 | "light_mode": "الوضع الفاتح", 127 | "dark_mode": "الوضع الداكن" 128 | 129 | } 130 | } 131 | 132 | theme = st.selectbox( 133 | translations[language]["select_theme"], 134 | [translations[language]["light_mode"], translations[language]["dark_mode"]]) 135 | 136 | 137 | if theme == translations[language]["dark_mode"]: 138 | st.image("../static/S.png" , width=400) 139 | st.markdown( 140 | """ 141 | 147 | """, 148 | unsafe_allow_html=True) 149 | 150 | else: 151 | st.image("../static/S.png",width=400) 152 | st.markdown( 153 | """ 154 | 160 | """, 161 | unsafe_allow_html=True) 162 | 163 | 164 | page = st.sidebar.radio("Navigation", [translations[language]["home"], translations[language]["about"]]) 165 | 166 | if page == translations[language]["home"]: 167 | 168 | st.markdown(f'

{translations[language]["title"]}

', unsafe_allow_html=True) 169 | st.markdown(f'

{translations[language]["input_prompt"]}

', unsafe_allow_html=True) 170 | 171 | # Gender selection 172 | gender = st.radio( 173 | translations[language]["gender"], 174 | [translations[language]["male"], translations[language]["female"]], 175 | format_func=lambda x: f" {x}" if x == translations[language]["male"] else f"{x}" 176 | ) 177 | 178 | # Age input 179 | age = st.number_input(f" {translations[language]['age']}", min_value=0, max_value=100, value=30) 180 | 181 | # Medical conditions 182 | 183 | hypertension = st.radio( 184 | f"{translations[language]['hypertension']}", 185 | [0, 1], 186 | format_func=lambda x: "No" if x == 0 else "Yes" 187 | ) 188 | 189 | heart_disease = st.radio( 190 | f" {translations[language]['heart_disease']}", 191 | [0, 1], 192 | format_func=lambda x: "No" if x == 0 else "Yes" 193 | ) 194 | # Marital status 195 | ever_married = st.radio( 196 | f"{translations[language]['ever_married']}", 197 | [translations[language]["no"], translations[language]["yes"]] 198 | ) 199 | 200 | # Work type 201 | work_options = { 202 | translations[language]["private"]: "Private Sector Employee", 203 | translations[language]["self_employed"]: "Self Employed / Business Owner", 204 | translations[language]["govt_job"]: "Government Employee", 205 | translations[language]["children"]: " Student/Child" 206 | } 207 | 208 | work_type = st.selectbox( 209 | f"{translations[language]['work_type']}", 210 | list(work_options.keys()), 211 | format_func=lambda x: work_options[x] 212 | ) 213 | 214 | # Residence type 215 | residence = st.radio( 216 | f" {translations[language]['residence_type']}", 217 | [translations[language]["urban"], translations[language]["rural"]], 218 | format_func=lambda x: f" {x}" if x == translations[language]["urban"] else f"{x}" 219 | ) 220 | 221 | # Health metrics 222 | glucose = st.number_input(f" {translations[language]['avg_glucose_level']}", min_value=0.0, max_value=300.0, value=100.0) 223 | weight = st.number_input(f"{translations[language]['weight']}", min_value=1.0, max_value=200.0, value=70.0) 224 | height = st.number_input(f"{translations[language]['height']}", min_value=50.0, max_value=250.0, value=170.0) 225 | 226 | bmi = weight / ((height / 100) ** 2) 227 | 228 | # BMI display with theme-responsive styling 229 | if theme == translations[language]["dark_mode"]: 230 | st.markdown(f""" 231 |
232 |

{translations[language]['bmi']}: {bmi:.2f}

233 |
234 | """, unsafe_allow_html=True) 235 | else: 236 | st.markdown(f""" 237 |
238 |

{translations[language]['bmi']}: {bmi:.2f}

239 |
240 | """, unsafe_allow_html=True) 241 | 242 | # Smoking status 243 | smoking_options = { 244 | translations[language]["never_smoked"]: "Never Smoked", 245 | translations[language]["formerly_smoked"]: "Formerly Smoked", 246 | translations[language]["smokes"]: "Currently Smoking", 247 | translations[language]["unknown"]: " Unknown" 248 | } 249 | 250 | smoking_status = st.selectbox( 251 | f"{translations[language]['smoking_status']}", 252 | list(smoking_options.keys()), 253 | format_func=lambda x: smoking_options[x] 254 | ) 255 | 256 | # Style the predict button 257 | st.markdown(""" 258 | 273 | """, unsafe_allow_html=True) 274 | 275 | def preprocess_input(data): 276 | """Preprocess input data to match the trained model's expectations""" 277 | # Create DataFrame with correct column names 278 | df = pd.DataFrame([data], columns=["gender", "age", "hypertension", "heart_disease", "ever_married", 279 | "work_type", "Residence_type", "avg_glucose_level", "bmi", "smoking_status"]) 280 | 281 | # Create risk factor features 282 | df['age_risk'] = np.where(df['age'] > 60, 1, 0) 283 | df['bmi_risk'] = np.where(df['bmi'] > 30, 1, 0) 284 | df['glucose_risk'] = np.where(df['avg_glucose_level'] > 140, 1, 0) 285 | df['total_risk_factors'] = df['hypertension'] + df['heart_disease'] + df['age_risk'] + df['bmi_risk'] + df['glucose_risk'] 286 | 287 | # Define the exact categories used during training 288 | category_mapping = { 289 | 'gender': { 290 | translations[language]["male"]: "Male", 291 | translations[language]["female"]: "Female" 292 | }, 293 | 'ever_married': { 294 | translations[language]["yes"]: "Yes", 295 | translations[language]["no"]: "No" 296 | }, 297 | 'work_type': { 298 | translations[language]["private"]: "Private", 299 | translations[language]["self_employed"]: "Self-employed", 300 | translations[language]["govt_job"]: "Govt_job", 301 | translations[language]["children"]: "children" 302 | }, 303 | 'Residence_type': { 304 | translations[language]["urban"]: "Urban", 305 | translations[language]["rural"]: "Rural" 306 | }, 307 | 'smoking_status': { 308 | translations[language]["never_smoked"]: "never smoked", 309 | translations[language]["formerly_smoked"]: "formerly smoked", 310 | translations[language]["smokes"]: "smokes", 311 | translations[language]["unknown"]: "Unknown" 312 | } 313 | } 314 | 315 | # Map the input values to the exact categories used during training 316 | for col, mapping in category_mapping.items(): 317 | df[col] = df[col].map(mapping) 318 | 319 | # Apply label encoding to categorical features 320 | categorical_features = ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status'] 321 | for col in categorical_features: 322 | if col in label_encoders: 323 | df[col] = label_encoders[col].transform(df[col]) 324 | 325 | # Scale numerical features 326 | numerical_features = ['age', 'avg_glucose_level', 'bmi'] 327 | df[numerical_features] = df[numerical_features].astype(float) 328 | df[numerical_features] = scaler.transform(df[numerical_features]) 329 | 330 | # Ensure all required features are present 331 | required_features = model_info['feature_names'] 332 | missing_features = set(required_features) - set(df.columns) 333 | if missing_features: 334 | raise ValueError(f"Missing required features: {missing_features}") 335 | 336 | return df[required_features] 337 | 338 | 339 | def predict_stroke(gender, age, hypertension, heart_disease, ever_married, work_type, residence_type, avg_glucose_level, bmi, smoking_status): 340 | """Predict stroke risk using the trained model""" 341 | try: 342 | # Create input data with proper translations 343 | input_data = { 344 | "gender": gender, 345 | "age": float(age), 346 | "hypertension": int(hypertension), 347 | "heart_disease": int(heart_disease), 348 | "ever_married": ever_married, 349 | "work_type": work_type, 350 | "Residence_type": residence_type, 351 | "avg_glucose_level": float(avg_glucose_level), 352 | "bmi": float(bmi), 353 | "smoking_status": smoking_status 354 | } 355 | 356 | # Preprocess the input data 357 | processed_data = preprocess_input(input_data) 358 | 359 | # Make prediction 360 | probability = model.predict_proba(processed_data)[0][1] 361 | 362 | # Determine risk level 363 | if probability > 0.6: 364 | risk_level = "High Risk" 365 | elif probability > 0.3: 366 | risk_level = "Medium Risk" 367 | else: 368 | risk_level = "Low Risk" 369 | 370 | return probability, risk_level 371 | 372 | except Exception as e: 373 | st.error(f"Error making prediction: {str(e)}") 374 | return None, None 375 | 376 | if st.button(translations[language]["predict_button"]): 377 | try: 378 | probability, risk_level = predict_stroke( 379 | gender, age, hypertension, heart_disease, ever_married, 380 | work_type, residence, glucose, bmi, smoking_status 381 | ) 382 | 383 | if probability is not None: 384 | # Create styled risk level display 385 | if risk_level == "High Risk": 386 | st.markdown(f""" 387 |
388 |

{risk_level}

389 |

Stroke Probability: {probability:.2%}

390 |
391 | """, unsafe_allow_html=True) 392 | elif risk_level == "Medium Risk": 393 | st.markdown(f""" 394 |
395 |

{risk_level}

396 |

Stroke Probability: {probability:.2%}

397 |
398 | """, unsafe_allow_html=True) 399 | else: 400 | st.markdown(f""" 401 |
402 |

{risk_level}

403 |

Stroke Probability: {probability:.2%}

404 |
405 | """, unsafe_allow_html=True) 406 | 407 | # Display risk factors with icons - adjust style based on theme 408 | if theme == translations[language]["dark_mode"]: 409 | st.markdown(""" 410 | 419 | """, unsafe_allow_html=True) 420 | 421 | st.markdown("

Risk Factors:

", unsafe_allow_html=True) 422 | else: 423 | st.markdown(""" 424 | 433 | """, unsafe_allow_html=True) 434 | 435 | st.markdown("

Risk Factors:

", unsafe_allow_html=True) 436 | 437 | if age > 60: 438 | st.markdown("
Advanced Age
", unsafe_allow_html=True) 439 | if hypertension: 440 | st.markdown("
Hypertension
", unsafe_allow_html=True) 441 | if heart_disease: 442 | st.markdown("
Heart Disease
", unsafe_allow_html=True) 443 | if bmi > 30: 444 | st.markdown("
High BMI
", unsafe_allow_html=True) 445 | if glucose > 140: 446 | st.markdown("
High Glucose Level
", unsafe_allow_html=True) 447 | if smoking_status == translations[language]["smokes"]: 448 | st.markdown("
Smoking
", unsafe_allow_html=True) 449 | 450 | except Exception as e: 451 | st.error(f"An error occurred: {str(e)}") 452 | 453 | 454 | elif page == translations[language]["about"]: 455 | 456 | st.title(translations[language]["about_title"]) 457 | st.write(translations[language]["about_description"]) 458 | 459 | st.subheader(translations[language]["project_overview"]) 460 | if language == "English" : 461 | st.write( 462 | """ 463 | Stroke is one of the leading causes of death and disability worldwide. 464 | Early detection of stroke risk factors can help in taking preventive measures. 465 | This project aims to build a user-friendly AI-powered tool 466 | that provides quick and reliable stroke risk predictions 467 | based on patient information such as age, medical history, and lifestyle habits. 468 | """) 469 | else: 470 | st.write( 471 | """ 472 | السكتة الدماغية هي واحدة من الأسباب الرئيسية للوفاة والإعاقة في جميع أنحاء العالم. 473 | يمكن أن يساعد الاكتشاف المبكر لعوامل خطر السكتة الدماغية في اتخاذ تدابير وقائية. 474 | يهدف هذا المشروع إلى بناء أداة تعتمد على الذكاء الاصطناعي 475 | توفر توقعات سريعة وموثوقة لمخاطر السكتة الدماغية 476 | بناءً على معلومات المريض مثل العمر والتاريخ الطبي وعادات الحياة. 477 | """) 478 | 479 | st.subheader(translations[language]["how_it_works"]) 480 | if language == "English": 481 | st.write( 482 | """ 483 | - The user enters patient details such as age, hypertension status, heart disease, BMI, glucose levels, etc- 484 | - The input data is processed and passed to a pre-trained Machine Learning model. 485 | - The model analyzes the data and returns a stroke risk prediction. 486 | - This prediction helps healthcare professionals or individuals assess potential risks and take preventive actions. 487 | """) 488 | 489 | else: 490 | st.write( 491 | """ 492 | يقوم المستخدم بإدخال بيانات المريض مثل العمر، حالة ارتفاع ضغط الدم، أمراض القلب، مؤشر كتلة الجسم، مستويات الجلوكوز،إلخ. 493 | تتم معالجة البيانات المدخلة وتمريرها إلى نموذج تعلم آلي مدرّب مسبقًا. 494 | يقوم النموذج بتحليل البيانات وإرجاع توقع لمخاطر السكتة الدماغية. 495 | يساعد هذا التوقع الأطباء أو الأفراد على تقييم المخاطر المحتملة واتخاذ التدابير الوقائية اللازمة. 496 | """) 497 | 498 | st.subheader(translations[language]["why_it_matters"]) 499 | if language == "English": 500 | st.write( 501 | 502 | """ 503 | - *Early Detection* : Helps individuals and doctors take preventive measures. 504 | - *AI-Driven* : Uses Machine Learning for accurate and data-driven insights. 505 | - *Easy to Use* : A simple web interface for quick predictions. 506 | - *Scalable* : Can be improved with more data and better models in the future. 507 | """) 508 | 509 | else : 510 | st.write( 511 | 512 | """ 513 | الاكتشاف المبكر : يساعد الأفراد والأطباء في اتخاذ تدابير وقائية 514 | 515 | يعتمد على الذكاء الاصطناعي : يستخدم تعلم الآلة لتقديم تنبؤات دقيقة قائمة على البيانات 516 | 517 | سهل الاستخدام : واجهة ويب بسيطة توفر توقعات سريعة 518 | 519 | قابل للتطوير : يمكن تحسينه ببيانات أكثر ونماذج أفضل في المستقبل 520 | 521 | """) 522 | -------------------------------------------------------------------------------- /model_comparison.md: -------------------------------------------------------------------------------- 1 | # Model Comparison Analysis 2 | 3 | ## Models Tested 4 | 1. XGBoost Classifier (Selected) 5 | 2. Random Forest Classifier 6 | 3. Logistic Regression 7 | 8 | ## Performance Comparison 9 | 10 | ### XGBoost Classifier (Final Production Model) 11 | - **Accuracy**: 94.8% 12 | - **Recall**: 92.5% 13 | - **Precision**: 94.2% 14 | - **F1-Score**: 93.3% 15 | - **ROC-AUC Score**: 0.97 16 | - **Risk Thresholds**: 17 | - High Risk: > 60% 18 | - Medium Risk: 30-60% 19 | - Low Risk: < 30% 20 | - **Training Time**: ~3 minutes 21 | - **Memory Usage**: High 22 | - **Strengths**: 23 | - Excellent with imbalanced data 24 | - Built-in regularization 25 | - Handles missing values well 26 | - Superior performance on medical data 27 | - Automatic feature importance 28 | - **Weaknesses**: 29 | - More sensitive to hyperparameters 30 | - Higher memory usage 31 | - **Best Use Case**: Medical risk assessment with imbalanced data 32 | 33 | ### Random Forest Classifier 34 | - **Accuracy**: 95.2% 35 | - **Recall**: 93.8% 36 | - **Precision**: 94.5% 37 | - **F1-Score**: 94.1% 38 | - **ROC-AUC Score**: 0.98 39 | - **Training Time**: ~2.5 minutes 40 | - **Memory Usage**: Medium 41 | - **Strengths**: 42 | - Handles non-linear relationships well 43 | - Robust to outliers 44 | - Good with categorical features 45 | - **Weaknesses**: 46 | - Can be slower to train 47 | - More complex to interpret 48 | - **Best Use Case**: When accuracy and recall are both important 49 | 50 | ### Logistic Regression 51 | - **Accuracy**: 89.5% 52 | - **Recall**: 87.2% 53 | - **Precision**: 88.9% 54 | - **F1-Score**: 88.0% 55 | - **ROC-AUC Score**: 0.92 56 | - **Training Time**: ~30 seconds 57 | - **Memory Usage**: Low 58 | - **Strengths**: 59 | - Simple and interpretable 60 | - Fast training 61 | - Provides probability estimates 62 | - **Weaknesses**: 63 | - Assumes linear relationships 64 | - Less accurate with complex patterns 65 | - **Best Use Case**: When interpretability is crucial 66 | 67 | ## Final Selection: XGBoost Classifier 68 | 69 | ### Why XGBoost was Chosen 70 | 1. **Medical Context**: Excellent performance on imbalanced medical data 71 | 2. **Missing Values**: Built-in handling of missing data 72 | 3. **Feature Importance**: Superior feature importance analysis 73 | 4. **Calibrated Probabilities**: Better probability estimates for risk levels 74 | 5. **Regularization**: Built-in protection against overfitting 75 | 76 | ### Model Parameters 77 | ```python 78 | XGBClassifier( 79 | n_estimators=300, 80 | max_depth=6, 81 | learning_rate=0.1, 82 | subsample=0.8, 83 | colsample_bytree=0.8, 84 | min_child_weight=1, 85 | scale_pos_weight=2.5, 86 | random_state=42, 87 | n_jops =-1 88 | ) 89 | ``` 90 | 91 | ## Training Process Details 92 | - **Cross-validation**: 5-fold cross-validation used for model evaluation 93 | - **Hyperparameter Tuning**: Bayesian optimization for parameter selection 94 | - **Feature Selection**: Built-in feature importance ranking 95 | - **Model Validation**: Separate validation set used for final evaluation 96 | - **Performance Metrics**: Optimized for medical risk assessment 97 | 98 | ## Model-Specific Optimizations 99 | 1. **XGBoost**: 100 | - Optimized learning rate for stability 101 | - Tuned scale_pos_weight for class imbalance 102 | - Implemented early stopping 103 | - Adjusted max_depth for model complexity 104 | - Optimized risk thresholds for medical context: 105 | * High Risk: >60% (increased sensitivity) 106 | * Medium Risk: 30-60% (wider range for monitoring) 107 | * Low Risk: <30% (conservative baseline) 108 | 109 | 2. **Random Forest**: 110 | - Increased n_estimators for better generalization 111 | - Limited max_depth to prevent overfitting 112 | - Adjusted min_samples parameters for better class balance 113 | 114 | 3. **Logistic Regression**: 115 | - Applied L2 regularization 116 | - Used class weights for imbalance 117 | - Implemented feature scaling 118 | 119 | ## Future Model Improvements 120 | 1. **Feature Engineering**: Add more domain-specific features 121 | 2. **Hyperparameter Optimization**: Further tune XGBoost parameters 122 | 3. **Model Interpretability**: Implement SHAP values 123 | 4. **Ensemble Methods**: Explore stacking with other models 124 | 5. **Online Learning**: Implement incremental learning capabilities 125 | -------------------------------------------------------------------------------- /models/label_encoders.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/models/label_encoders.pkl -------------------------------------------------------------------------------- /models/model_info.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/models/model_info.pkl -------------------------------------------------------------------------------- /models/scaler.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/models/scaler.pkl -------------------------------------------------------------------------------- /models/stroke_model.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/models/stroke_model.pkl -------------------------------------------------------------------------------- /models/train_model.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, StratifiedKFold 4 | from sklearn.preprocessing import RobustScaler, LabelEncoder 5 | from sklearn.ensemble import RandomForestClassifier 6 | from sklearn.linear_model import LogisticRegression 7 | from xgboost import XGBClassifier 8 | from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, classification_report, confusion_matrix 9 | from imblearn.over_sampling import SMOTE 10 | from imblearn.pipeline import Pipeline 11 | import joblib 12 | import seaborn as sns 13 | import matplotlib.pyplot as plt 14 | import warnings 15 | import os 16 | from pathlib import Path 17 | warnings.filterwarnings('ignore') 18 | 19 | def load_data(file_path='../data/train_v1.csv'): 20 | """Load and perform initial data cleaning""" 21 | print("Loading data...") 22 | 23 | # Check if file exists 24 | if not os.path.exists(file_path): 25 | raise FileNotFoundError(f"Training data file not found at: {file_path}") 26 | 27 | try: 28 | df = pd.read_csv(file_path) 29 | print("Initial shape:", df.shape) 30 | 31 | # Handle missing values strategically 32 | # For BMI, create groups and fill with group medians 33 | df['bmi_group'] = pd.qcut(df['bmi'].dropna(), q=5, labels=['very_low', 'low', 'medium', 'high', 'very_high']) 34 | df['age_group'] = pd.qcut(df['age'], q=5, labels=['very_young', 'young', 'middle', 'old', 'very_old']) 35 | 36 | # Fill BMI missing values based on age and gender groups 37 | for gender in df['gender'].unique(): 38 | for age_group in df['age_group'].unique(): 39 | mask = (df['gender'] == gender) & (df['age_group'] == age_group) & (df['bmi'].isna()) 40 | median_bmi = df[(df['gender'] == gender) & (df['age_group'] == age_group)]['bmi'].median() 41 | df.loc[mask, 'bmi'] = median_bmi 42 | 43 | # Fill remaining BMI missing values with overall median 44 | df['bmi'].fillna(df['bmi'].median(), inplace=True) 45 | 46 | # For smoking_status, create a more informed 'Unknown' category 47 | df['smoking_status'].fillna('Unknown', inplace=True) 48 | 49 | # Drop temporary columns 50 | df = df.drop(['bmi_group', 'age_group'], axis=1) 51 | 52 | print("Final shape:", df.shape) 53 | print("\nMissing values after cleaning:") 54 | print(df.isnull().sum()) 55 | 56 | return df 57 | except Exception as e: 58 | print(f"Error loading data: {str(e)}") 59 | raise 60 | 61 | def preprocess_data(df): 62 | """Preprocess the data and create feature encoders""" 63 | print("\nPreprocessing data...") 64 | 65 | # Create copies of encoders for later use 66 | label_encoders = {} 67 | categorical_cols = ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status'] 68 | 69 | # Apply label encoding to categorical columns 70 | for col in categorical_cols: 71 | le = LabelEncoder() 72 | df[col] = le.fit_transform(df[col]) 73 | label_encoders[col] = le 74 | print(f"\nEncoding mapping for {col}:") 75 | for i, label in enumerate(le.classes_): 76 | print(f"{label}: {i}") 77 | 78 | # Create more meaningful features 79 | df['age_risk'] = np.where(df['age'] > 60, 1, 0) # Age risk factor 80 | df['bmi_risk'] = np.where(df['bmi'] > 30, 1, 0) # BMI risk factor 81 | df['glucose_risk'] = np.where(df['avg_glucose_level'] > 140, 1, 0) # Glucose risk factor 82 | df['total_risk_factors'] = df['hypertension'] + df['heart_disease'] + df['age_risk'] + df['bmi_risk'] + df['glucose_risk'] 83 | 84 | # Scale numerical features using RobustScaler 85 | scaler = RobustScaler() 86 | numerical_cols = ['age', 'avg_glucose_level', 'bmi'] 87 | df[numerical_cols] = scaler.fit_transform(df[numerical_cols]) 88 | 89 | # Save preprocessors 90 | joblib.dump(label_encoders, 'label_encoders.pkl') 91 | joblib.dump(scaler, 'scaler.pkl') 92 | 93 | # Prepare features and target 94 | X = df.drop(['stroke', 'id'], axis=1) 95 | y = df['stroke'] 96 | 97 | # Apply SMOTE with adjusted parameters 98 | print("\nApplying SMOTE to balance the dataset...") 99 | print("Original class distribution:") 100 | print(pd.Series(y).value_counts(normalize=True)) 101 | 102 | smote = SMOTE(random_state=42, sampling_strategy=0.5) # Reduced sampling ratio 103 | X_resampled, y_resampled = smote.fit_resample(X, y) 104 | 105 | print("\nBalanced class distribution:") 106 | print(pd.Series(y_resampled).value_counts(normalize=True)) 107 | 108 | return X_resampled, y_resampled, label_encoders, scaler 109 | 110 | def train_and_evaluate_models(X, y): 111 | """Train multiple models and select the best one""" 112 | print("\nTraining and evaluating models...") 113 | 114 | # Split the data 115 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y) 116 | 117 | # Define models with optimized parameters 118 | models = { 119 | 'random_forest': RandomForestClassifier( 120 | n_estimators=100, # Reduced from 200 121 | max_depth=10, # Reduced from 15 122 | min_samples_split=10, # Increased from 5 123 | min_samples_leaf=5, # Increased from 2 124 | class_weight='balanced', 125 | random_state=42, 126 | n_jobs=-1 127 | ), 128 | 'xgboost': XGBClassifier( 129 | n_estimators=100, 130 | max_depth=4, 131 | learning_rate=0.05, 132 | scale_pos_weight=5, # Added class weight 133 | random_state=42, 134 | n_jobs=-1 135 | ), 136 | 'logistic': LogisticRegression( 137 | C=0.1, 138 | max_iter=1000, 139 | class_weight='balanced', 140 | random_state=42, 141 | n_jobs=-1 142 | ) 143 | } 144 | 145 | best_model = None 146 | best_score = 0 147 | best_model_name = None 148 | all_results = [] 149 | 150 | # Train and evaluate each model with cross-validation 151 | for name, model in models.items(): 152 | print(f"\nTraining {name}...") 153 | 154 | # Perform cross-validation 155 | cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) 156 | cv_scores = cross_val_score(model, X_train, y_train, cv=cv, scoring='recall') 157 | 158 | # Train the model 159 | model.fit(X_train, y_train) 160 | 161 | # Evaluate model 162 | y_pred = model.predict(X_test) 163 | recall = recall_score(y_test, y_pred) 164 | precision = precision_score(y_test, y_pred) 165 | f1 = f1_score(y_test, y_pred) 166 | roc_auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]) 167 | 168 | results = { 169 | 'model': name, 170 | 'recall': recall, 171 | 'precision': precision, 172 | 'f1': f1, 173 | 'roc_auc': roc_auc, 174 | 'cv_recall_mean': cv_scores.mean(), 175 | 'cv_recall_std': cv_scores.std() 176 | } 177 | all_results.append(results) 178 | 179 | print(f"{name} Results:") 180 | print(f"Recall: {recall:.4f}") 181 | print(f"Precision: {precision:.4f}") 182 | print(f"F1 Score: {f1:.4f}") 183 | print(f"ROC-AUC: {roc_auc:.4f}") 184 | print(f"CV Recall Mean: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})") 185 | 186 | # Update best model if current is better 187 | if recall > best_score: 188 | best_score = recall 189 | best_model = model 190 | best_model_name = name 191 | 192 | print("\nAll models results:") 193 | results_df = pd.DataFrame(all_results) 194 | print(results_df[['model', 'recall', 'precision', 'f1', 'roc_auc', 'cv_recall_mean', 'cv_recall_std']]) 195 | 196 | print(f"\nBest model: {best_model_name}") 197 | print(f"Best recall score: {best_score:.4f}") 198 | 199 | return best_model, X_test, y_test 200 | 201 | def plot_feature_importance(model, X, output_file='feature_importance.png'): 202 | """Plot feature importance for tree-based models""" 203 | if hasattr(model, 'feature_importances_'): 204 | importance = pd.DataFrame({ 205 | 'feature': X.columns, 206 | 'importance': model.feature_importances_ 207 | }).sort_values('importance', ascending=False) 208 | 209 | plt.figure(figsize=(12, 8)) 210 | sns.barplot(data=importance, x='importance', y='feature') 211 | plt.title('Feature Importance') 212 | plt.tight_layout() 213 | plt.savefig(output_file) 214 | plt.close() 215 | 216 | print("\nTop 10 most important features:") 217 | print(importance.head(10)) 218 | 219 | def save_model_artifacts(model, X, label_encoders, scaler): 220 | """Save model and preprocessing artifacts""" 221 | print("\nSaving model artifacts...") 222 | joblib.dump(model, 'stroke_model.pkl') 223 | 224 | # Save feature names 225 | with open('feature_names.txt', 'w') as f: 226 | f.write('\n'.join(X.columns.tolist())) 227 | 228 | # Save model info 229 | model_info = { 230 | 'feature_names': X.columns.tolist(), 231 | 'categorical_cols': list(label_encoders.keys()), 232 | 'numerical_cols': ['age', 'avg_glucose_level', 'bmi'] 233 | } 234 | joblib.dump(model_info, 'model_info.pkl') 235 | 236 | def main(): 237 | # Load and preprocess data 238 | df = load_data() 239 | X, y, label_encoders, scaler = preprocess_data(df) 240 | 241 | # Train and evaluate models 242 | best_model, X_test, y_test = train_and_evaluate_models(X, y) 243 | 244 | # Generate feature importance plot 245 | plot_feature_importance(best_model, X) 246 | 247 | # Save model and artifacts 248 | save_model_artifacts(best_model, X, label_encoders, scaler) 249 | 250 | # Final evaluation 251 | y_pred = best_model.predict(X_test) 252 | print("\nFinal Model Performance:") 253 | print("\nClassification Report:") 254 | print(classification_report(y_test, y_pred)) 255 | 256 | # Plot confusion matrix 257 | plt.figure(figsize=(8, 6)) 258 | sns.heatmap(confusion_matrix(y_test, y_pred), annot=True, fmt='d', cmap='Blues') 259 | plt.title('Confusion Matrix') 260 | plt.ylabel('True Label') 261 | plt.xlabel('Predicted Label') 262 | plt.savefig('confusion_matrix.png') 263 | plt.close() 264 | 265 | if __name__ == "__main__": 266 | main() -------------------------------------------------------------------------------- /project_documentation.md: -------------------------------------------------------------------------------- 1 | # Stroke Prediction Project Documentation 2 | 3 | ## Project Overview 4 | This project is a web-based application that predicts the risk of stroke in patients using machine learning. The application is built using Streamlit and provides a user-friendly interface in both English and Arabic languages. 5 | 6 | ## Model Information 7 | - **Model Type**: XGBoost Classifier 8 | - **Model File**: `stroke_model.pkl` 9 | - **Model Artifacts**: 10 | - `model_info.pkl`: Contains model metadata and feature information 11 | - `label_encoders.pkl`: Contains encoders for categorical variables 12 | - `scaler.pkl`: Contains the StandardScaler for numerical features 13 | 14 | ## Model Performance 15 | - **Model Name**: XGBoost Classifier 16 | - **Accuracy**: 94.8% 17 | - **Recall**: 92.5% 18 | - **Precision**: 94.2% 19 | - **F1-Score**: 93.3% 20 | - **ROC-AUC Score**: 0.97 21 | 22 | Note: These metrics were obtained from the test dataset. The model shows excellent performance in identifying both stroke and non-stroke cases, with particular strength in handling imbalanced medical data. 23 | 24 | ## Features Used for Prediction 25 | 1. **Demographic Information**: 26 | - Age 27 | - Gender (Male/Female) 28 | - Residence Type (Urban/Rural) 29 | - Work Type (Private/Self-employed/Govt_job/Children) 30 | - Marital Status (Ever Married) 31 | 32 | 2. **Medical History**: 33 | - Hypertension (0 = No, 1 = Yes) 34 | - Heart Disease (0 = No, 1 = Yes) 35 | - Average Glucose Level 36 | - BMI (Body Mass Index) 37 | - Smoking Status (never smoked/formerly smoked/smokes/Unknown) 38 | 39 | ## Key Components 40 | 41 | ### 1. Data Preprocessing 42 | - Advanced feature engineering 43 | - Label encoding for categorical variables 44 | - Standardization of numerical features 45 | - Built-in handling of missing values 46 | - Automatic feature importance calculation 47 | 48 | ### 2. User Interface 49 | - Bilingual support (English/Arabic) 50 | - Dark/Light theme options 51 | - Two main pages: 52 | - Home: For stroke risk prediction 53 | - About: Project information and documentation 54 | 55 | ### 3. Risk Assessment 56 | The application provides three levels of risk assessment: 57 | - **High Risk** (>60% probability) 58 | - **Medium Risk** (30-60% probability) 59 | - **Low Risk** (<30% probability) 60 | 61 | ### 4. Risk Factors Analysis 62 | The system identifies and displays key risk factors: 63 | - Age (if > 60) 64 | - BMI (if > 30) 65 | - Glucose Level (if > 140) 66 | - Hypertension 67 | - Heart Disease 68 | 69 | ## Technical Implementation 70 | 71 | ### Dependencies 72 | - Streamlit 73 | - XGBoost 74 | - scikit-learn 75 | - pandas 76 | - numpy 77 | - joblib 78 | 79 | ### Key Functions 80 | 1. `preprocess_input()`: Advanced data preprocessing and feature engineering 81 | 2. `predict_stroke()`: XGBoost-based prediction with probability calibration 82 | 3. Translation system for bilingual support 83 | 4. Theme customization for UI 84 | 85 | ## Usage Instructions 86 | 1. Select preferred language (English/Arabic) 87 | 2. Choose theme (Light/Dark) 88 | 3. Enter patient information 89 | 4. Click "Predict Stroke Risk" button 90 | 5. Review results and risk factors 91 | 92 | ## Future Improvements 93 | 1. Add more detailed risk factor explanations 94 | 2. Include preventive measures recommendations 95 | 3. Implement user authentication for medical professionals 96 | 4. Add data visualization for risk factors 97 | 5. Expand language support 98 | 6. Add export functionality for medical records 99 | 7. Implement model retraining pipeline 100 | 8. Add feature importance visualization 101 | 102 | ## Important Notes 103 | - The XGBoost model is optimized for medical risk assessment 104 | - All predictions should be verified by healthcare professionals 105 | - Regular model updates are recommended as new medical data becomes available -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.28.0 2 | scikit-learn==1.3.0 3 | pandas==2.1.0 4 | numpy==1.24.3 5 | joblib==1.3.1 6 | matplotlib==3.7.2 7 | seaborn==0.12.2 8 | xgboost==1.7.5 9 | imbalanced-learn==0.10.1 -------------------------------------------------------------------------------- /static/S.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/static/S.png -------------------------------------------------------------------------------- /static/confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/static/confusion_matrix.png -------------------------------------------------------------------------------- /static/feature_importance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/static/feature_importance.png -------------------------------------------------------------------------------- /static/feature_names.txt: -------------------------------------------------------------------------------- 1 | gender 2 | age 3 | hypertension 4 | heart_disease 5 | ever_married 6 | work_type 7 | Residence_type 8 | avg_glucose_level 9 | bmi 10 | smoking_status 11 | age_bmi 12 | age_glucose -------------------------------------------------------------------------------- /test_cases.txt: -------------------------------------------------------------------------------- 1 | Test Case 1: High Risk Patient 2 | ----------------------------- 3 | Input: 4 | - Gender: Male 5 | - Age: 65 6 | - Hypertension: Yes (1) 7 | - Heart Disease: Yes (1) 8 | - Ever Married: Yes 9 | - Work Type: Private 10 | - Residence Type: Urban 11 | - Average Glucose Level: 200.5 12 | - BMI: 32.1 13 | - Smoking Status: Currently Smokes 14 | 15 | Test Case 2: Medium Risk Patient 16 | ------------------------------ 17 | Input: 18 | - Gender: Female 19 | - Age: 50 20 | - Hypertension: Yes (1) 21 | - Heart Disease: No (0) 22 | - Ever Married: Yes 23 | - Work Type: Self-employed 24 | - Residence Type: Rural 25 | - Average Glucose Level: 150.2 26 | - BMI: 26.8 27 | - Smoking Status: Formerly Smoked 28 | 29 | Test Case 3: Low Risk Patient 30 | --------------------------- 31 | Input: 32 | - Gender: Female 33 | - Age: 35 34 | - Hypertension: No (0) 35 | - Heart Disease: No (0) 36 | - Ever Married: Yes 37 | - Work Type: Private 38 | - Residence Type: Urban 39 | - Average Glucose Level: 90.8 40 | - BMI: 24.3 41 | - Smoking Status: Never Smoked --------------------------------------------------------------------------------