├── LICENSE
├── README.md
├── app
└── app.py
├── data
└── train_v1.csv
├── model_comparison.md
├── models
├── label_encoders.pkl
├── model_info.pkl
├── scaler.pkl
├── stroke_model.pkl
└── train_model.py
├── project_documentation.md
├── requirements.txt
├── static
├── S.png
├── confusion_matrix.png
├── feature_importance.png
└── feature_names.txt
└── test_cases.txt
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Abdullah Fayed , Azza Sadek , Mayar Tamer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Stroke Risk Prediction System
2 |
3 | A machine learning-based web application that predicts the risk of stroke in patients using various health parameters. The application provides a user-friendly interface in both English and Arabic languages.
4 |
5 | ## Features
6 |
7 | - 🧠 Advanced XGBoost-based stroke risk prediction
8 | - 🌐 Bilingual support (English/Arabic)
9 | - 🎨 Dark/Light theme options
10 | - 📊 Detailed risk factor analysis
11 | - 💻 User-friendly web interface
12 | - 📱 Responsive design
13 | - 🔄 Built-in handling of missing values
14 | - 📈 Automatic feature importance calculation
15 |
16 | ## Project Structure
17 |
18 | ```
19 | stroke-prediction/
20 | ├── app/
21 | │ └── app.py # Streamlit application
22 | ├── models/ # Trained models and artifacts
23 | ├── static/ # Static assets
24 | ├── data/ # Dataset
25 | ├── docs/ # Documentation
26 | ├── requirements.txt # Python dependencies
27 | ├── README.md # This file
28 | ├── project_documentation.md # Workflow and design details
29 | └── model_comparison.md # ML models comparison
30 | ```
31 |
32 | ## Usage
33 | 1. Install dependencies:
34 | ```bash
35 | pip install -r requirements.txt
36 | ```
37 |
38 | 2. Start the Streamlit application:
39 | ```bash
40 | streamlit run app/app.py
41 | ```
42 |
43 | 3. Open your web browser and navigate to:
44 | ```
45 | http://localhost:8501
46 | ```
47 |
48 | ## Model Information
49 |
50 | The system uses an XGBoost Classifier with the following performance metrics:
51 | - Accuracy: 94.8%
52 | - Recall: 92.5%
53 | - Precision: 94.2%
54 | - F1-Score: 93.3%
55 | - ROC-AUC Score: 0.97
56 |
57 | Key Features:
58 | - Optimized for medical risk assessment
59 | - Excellent handling of imbalanced data
60 | - Built-in feature importance analysis
61 | - Robust to missing values
62 |
63 | For detailed model comparison and technical information, see [Model Comparison](model_comparison.md).
64 |
65 | ## License
66 |
67 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
68 |
--------------------------------------------------------------------------------
/app/app.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import joblib
3 | import pandas as pd
4 | import numpy as np
5 | from sklearn.preprocessing import StandardScaler
6 | from pathlib import Path
7 |
8 | # Load model and artifacts at startup
9 | model_path = Path("../models/stroke_model.pkl")
10 | model_info_path = Path("../models/model_info.pkl")
11 | label_encoders_path = Path("../models/label_encoders.pkl")
12 | scaler_path = Path("../models/scaler.pkl")
13 |
14 | try:
15 | model = joblib.load(model_path)
16 | model_info = joblib.load(model_info_path)
17 | label_encoders = joblib.load(label_encoders_path)
18 | scaler = joblib.load(scaler_path)
19 | print("Successfully loaded all model artifacts")
20 | except Exception as e:
21 | st.error(f"Error loading model artifacts: {str(e)}")
22 | st.stop()
23 |
24 | print("Loaded model features:", model.feature_names_in_)
25 |
26 | language = st.sidebar.radio("Select Language | اختر اللغة", ["English", "العربية"])
27 | translations = {
28 |
29 | "English": {
30 | "title": "Welcome to our System",
31 | "input_prompt": "Enter Patient Details to Predict Stroke Risk",
32 | "age": "Age",
33 | "hypertension": "Hypertension (0 = No, 1 = Yes)",
34 | "heart_disease": "Heart Disease (0 = No, 1 = Yes)",
35 | "bmi": "BMI",
36 | "glucose": "Average Glucose Level",
37 | "smoking_status": "Smoking Status",
38 | "predict_button": "Predict Stroke Risk",
39 | "about_title": "About This App",
40 | "about_description": "This application predicts the risk of stroke based on patient data...",
41 | "project_overview": "Project Overview",
42 | "how_it_works": "How It Works",
43 | "theme_label": "Select Theme",
44 | "home": "Home",
45 | "about": "About",
46 | "gender": "Gender",
47 | "male": "Male",
48 | "female": "Female",
49 | "age": "Age",
50 | "hypertension": "Hypertension",
51 | "heart_disease": "Heart Disease",
52 | "ever_married": "Ever Married",
53 | "no": "No",
54 | "yes": "Yes",
55 | "work_type": "Work Type",
56 | "private": "Private",
57 | "self_employed": "Self-employed",
58 | "govt_job": "Govt_job",
59 | "children": "children",
60 | "residence_type": "Residence Type",
61 | "urban": "Urban",
62 | "rural": "Rural",
63 | "avg_glucose_level": "Average Glucose Level",
64 | "weight": "Weight (kg)",
65 | "height": "Height (cm)",
66 | "bmi": "BMI",
67 | "smoking_status": "Smoking Status",
68 | "never_smoked": "never smoked",
69 | "formerly_smoked": "formerly smoked",
70 | "smokes": "smokes",
71 | "unknown": "Unknown",
72 | "why_it_matters": "Why It Matters",
73 | "select_theme": "Select theme",
74 | "predict_stroke_risk": "Predict Stroke Risk",
75 | "light_mode": "Light Mode",
76 | "dark_mode": "Dark Mode"
77 | },
78 |
79 |
80 | "العربية": {
81 | "title": "مرحبًا بك في نظامنا",
82 | "input_prompt": "أدخل تفاصيل المريض للتنبؤ بمخاطر السكتة الدماغية",
83 | "age": "العمر",
84 | "hypertension": "ارتفاع ضغط الدم (0 = لا, 1 = نعم)",
85 | "heart_disease": "أمراض القلب (0 = لا, 1 = نعم)",
86 | "bmi": "مؤشر كتلة الجسم",
87 | "glucose": "متوسط مستوى الجلوكوز",
88 | "smoking_status": "حالة التدخين",
89 | "predict_button": "توقع خطر السكتة الدماغية",
90 | "about_title": "عن هذا التطبيق",
91 | "about_description": "يقوم هذا التطبيق بتوقع خطر الإصابة بالسكتة الدماغية بناءً على بيانات المريض...",
92 | "project_overview": "نظرة عامة على المشروع",
93 | "how_it_works": "كيف يعمل",
94 | "theme_label": "اختر الوضع",
95 | "home": "الرئيسية",
96 | "about": "عن التطبيق",
97 | "gender": "الجنس",
98 | "male": "ذكر",
99 | "female": "أنثى",
100 | "age": "العمر",
101 | "hypertension": "ارتفاع ضغط الدم",
102 | "heart_disease": "أمراض القلب",
103 | "ever_married": "متزوج مسبقًا",
104 | "no": "لا",
105 | "yes": "نعم",
106 | "work_type": "نوع العمل",
107 | "private": "قطاع خاص",
108 | "self_employed": "عمل حر",
109 | "govt_job": "وظيفة حكومية",
110 | "children": "طالب/طفل",
111 | "residence_type": "نوع السكن",
112 | "urban": "مدني",
113 | "rural": "ريفي",
114 | "avg_glucose_level": "متوسط مستوى الجلوكوز",
115 | "weight": "الوزن (كجم)",
116 | "height": "الطول (سم)",
117 | "bmi": "مؤشر كتلة الجسم",
118 | "smoking_status": "حالة التدخين",
119 | "never_smoked": "لم يدخن أبدًا",
120 | "formerly_smoked": "كان مدخنًا سابقًا",
121 | "smokes": "يدخن",
122 | "unknown": "غير معروف",
123 | "why_it_matters": "لماذا هذا مهم",
124 | "select_theme": "اختر النمط",
125 | "predict_stroke_risk": "توقع خطر السكتة الدماغية",
126 | "light_mode": "الوضع الفاتح",
127 | "dark_mode": "الوضع الداكن"
128 |
129 | }
130 | }
131 |
132 | theme = st.selectbox(
133 | translations[language]["select_theme"],
134 | [translations[language]["light_mode"], translations[language]["dark_mode"]])
135 |
136 |
137 | if theme == translations[language]["dark_mode"]:
138 | st.image("../static/S.png" , width=400)
139 | st.markdown(
140 | """
141 |
147 | """,
148 | unsafe_allow_html=True)
149 |
150 | else:
151 | st.image("../static/S.png",width=400)
152 | st.markdown(
153 | """
154 |
160 | """,
161 | unsafe_allow_html=True)
162 |
163 |
164 | page = st.sidebar.radio("Navigation", [translations[language]["home"], translations[language]["about"]])
165 |
166 | if page == translations[language]["home"]:
167 |
168 | st.markdown(f'
{translations[language]["title"]}
', unsafe_allow_html=True)
169 | st.markdown(f'{translations[language]["input_prompt"]}
', unsafe_allow_html=True)
170 |
171 | # Gender selection
172 | gender = st.radio(
173 | translations[language]["gender"],
174 | [translations[language]["male"], translations[language]["female"]],
175 | format_func=lambda x: f" {x}" if x == translations[language]["male"] else f"{x}"
176 | )
177 |
178 | # Age input
179 | age = st.number_input(f" {translations[language]['age']}", min_value=0, max_value=100, value=30)
180 |
181 | # Medical conditions
182 |
183 | hypertension = st.radio(
184 | f"{translations[language]['hypertension']}",
185 | [0, 1],
186 | format_func=lambda x: "No" if x == 0 else "Yes"
187 | )
188 |
189 | heart_disease = st.radio(
190 | f" {translations[language]['heart_disease']}",
191 | [0, 1],
192 | format_func=lambda x: "No" if x == 0 else "Yes"
193 | )
194 | # Marital status
195 | ever_married = st.radio(
196 | f"{translations[language]['ever_married']}",
197 | [translations[language]["no"], translations[language]["yes"]]
198 | )
199 |
200 | # Work type
201 | work_options = {
202 | translations[language]["private"]: "Private Sector Employee",
203 | translations[language]["self_employed"]: "Self Employed / Business Owner",
204 | translations[language]["govt_job"]: "Government Employee",
205 | translations[language]["children"]: " Student/Child"
206 | }
207 |
208 | work_type = st.selectbox(
209 | f"{translations[language]['work_type']}",
210 | list(work_options.keys()),
211 | format_func=lambda x: work_options[x]
212 | )
213 |
214 | # Residence type
215 | residence = st.radio(
216 | f" {translations[language]['residence_type']}",
217 | [translations[language]["urban"], translations[language]["rural"]],
218 | format_func=lambda x: f" {x}" if x == translations[language]["urban"] else f"{x}"
219 | )
220 |
221 | # Health metrics
222 | glucose = st.number_input(f" {translations[language]['avg_glucose_level']}", min_value=0.0, max_value=300.0, value=100.0)
223 | weight = st.number_input(f"{translations[language]['weight']}", min_value=1.0, max_value=200.0, value=70.0)
224 | height = st.number_input(f"{translations[language]['height']}", min_value=50.0, max_value=250.0, value=170.0)
225 |
226 | bmi = weight / ((height / 100) ** 2)
227 |
228 | # BMI display with theme-responsive styling
229 | if theme == translations[language]["dark_mode"]:
230 | st.markdown(f"""
231 |
232 |
{translations[language]['bmi']}: {bmi:.2f}
233 |
234 | """, unsafe_allow_html=True)
235 | else:
236 | st.markdown(f"""
237 |
238 |
{translations[language]['bmi']}: {bmi:.2f}
239 |
240 | """, unsafe_allow_html=True)
241 |
242 | # Smoking status
243 | smoking_options = {
244 | translations[language]["never_smoked"]: "Never Smoked",
245 | translations[language]["formerly_smoked"]: "Formerly Smoked",
246 | translations[language]["smokes"]: "Currently Smoking",
247 | translations[language]["unknown"]: " Unknown"
248 | }
249 |
250 | smoking_status = st.selectbox(
251 | f"{translations[language]['smoking_status']}",
252 | list(smoking_options.keys()),
253 | format_func=lambda x: smoking_options[x]
254 | )
255 |
256 | # Style the predict button
257 | st.markdown("""
258 |
273 | """, unsafe_allow_html=True)
274 |
275 | def preprocess_input(data):
276 | """Preprocess input data to match the trained model's expectations"""
277 | # Create DataFrame with correct column names
278 | df = pd.DataFrame([data], columns=["gender", "age", "hypertension", "heart_disease", "ever_married",
279 | "work_type", "Residence_type", "avg_glucose_level", "bmi", "smoking_status"])
280 |
281 | # Create risk factor features
282 | df['age_risk'] = np.where(df['age'] > 60, 1, 0)
283 | df['bmi_risk'] = np.where(df['bmi'] > 30, 1, 0)
284 | df['glucose_risk'] = np.where(df['avg_glucose_level'] > 140, 1, 0)
285 | df['total_risk_factors'] = df['hypertension'] + df['heart_disease'] + df['age_risk'] + df['bmi_risk'] + df['glucose_risk']
286 |
287 | # Define the exact categories used during training
288 | category_mapping = {
289 | 'gender': {
290 | translations[language]["male"]: "Male",
291 | translations[language]["female"]: "Female"
292 | },
293 | 'ever_married': {
294 | translations[language]["yes"]: "Yes",
295 | translations[language]["no"]: "No"
296 | },
297 | 'work_type': {
298 | translations[language]["private"]: "Private",
299 | translations[language]["self_employed"]: "Self-employed",
300 | translations[language]["govt_job"]: "Govt_job",
301 | translations[language]["children"]: "children"
302 | },
303 | 'Residence_type': {
304 | translations[language]["urban"]: "Urban",
305 | translations[language]["rural"]: "Rural"
306 | },
307 | 'smoking_status': {
308 | translations[language]["never_smoked"]: "never smoked",
309 | translations[language]["formerly_smoked"]: "formerly smoked",
310 | translations[language]["smokes"]: "smokes",
311 | translations[language]["unknown"]: "Unknown"
312 | }
313 | }
314 |
315 | # Map the input values to the exact categories used during training
316 | for col, mapping in category_mapping.items():
317 | df[col] = df[col].map(mapping)
318 |
319 | # Apply label encoding to categorical features
320 | categorical_features = ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status']
321 | for col in categorical_features:
322 | if col in label_encoders:
323 | df[col] = label_encoders[col].transform(df[col])
324 |
325 | # Scale numerical features
326 | numerical_features = ['age', 'avg_glucose_level', 'bmi']
327 | df[numerical_features] = df[numerical_features].astype(float)
328 | df[numerical_features] = scaler.transform(df[numerical_features])
329 |
330 | # Ensure all required features are present
331 | required_features = model_info['feature_names']
332 | missing_features = set(required_features) - set(df.columns)
333 | if missing_features:
334 | raise ValueError(f"Missing required features: {missing_features}")
335 |
336 | return df[required_features]
337 |
338 |
339 | def predict_stroke(gender, age, hypertension, heart_disease, ever_married, work_type, residence_type, avg_glucose_level, bmi, smoking_status):
340 | """Predict stroke risk using the trained model"""
341 | try:
342 | # Create input data with proper translations
343 | input_data = {
344 | "gender": gender,
345 | "age": float(age),
346 | "hypertension": int(hypertension),
347 | "heart_disease": int(heart_disease),
348 | "ever_married": ever_married,
349 | "work_type": work_type,
350 | "Residence_type": residence_type,
351 | "avg_glucose_level": float(avg_glucose_level),
352 | "bmi": float(bmi),
353 | "smoking_status": smoking_status
354 | }
355 |
356 | # Preprocess the input data
357 | processed_data = preprocess_input(input_data)
358 |
359 | # Make prediction
360 | probability = model.predict_proba(processed_data)[0][1]
361 |
362 | # Determine risk level
363 | if probability > 0.6:
364 | risk_level = "High Risk"
365 | elif probability > 0.3:
366 | risk_level = "Medium Risk"
367 | else:
368 | risk_level = "Low Risk"
369 |
370 | return probability, risk_level
371 |
372 | except Exception as e:
373 | st.error(f"Error making prediction: {str(e)}")
374 | return None, None
375 |
376 | if st.button(translations[language]["predict_button"]):
377 | try:
378 | probability, risk_level = predict_stroke(
379 | gender, age, hypertension, heart_disease, ever_married,
380 | work_type, residence, glucose, bmi, smoking_status
381 | )
382 |
383 | if probability is not None:
384 | # Create styled risk level display
385 | if risk_level == "High Risk":
386 | st.markdown(f"""
387 |
388 |
{risk_level}
389 | Stroke Probability: {probability:.2%}
390 |
391 | """, unsafe_allow_html=True)
392 | elif risk_level == "Medium Risk":
393 | st.markdown(f"""
394 |
395 |
{risk_level}
396 | Stroke Probability: {probability:.2%}
397 |
398 | """, unsafe_allow_html=True)
399 | else:
400 | st.markdown(f"""
401 |
402 |
{risk_level}
403 | Stroke Probability: {probability:.2%}
404 |
405 | """, unsafe_allow_html=True)
406 |
407 | # Display risk factors with icons - adjust style based on theme
408 | if theme == translations[language]["dark_mode"]:
409 | st.markdown("""
410 |
419 | """, unsafe_allow_html=True)
420 |
421 | st.markdown("Risk Factors:
", unsafe_allow_html=True)
422 | else:
423 | st.markdown("""
424 |
433 | """, unsafe_allow_html=True)
434 |
435 | st.markdown("Risk Factors:
", unsafe_allow_html=True)
436 |
437 | if age > 60:
438 | st.markdown("Advanced Age
", unsafe_allow_html=True)
439 | if hypertension:
440 | st.markdown("Hypertension
", unsafe_allow_html=True)
441 | if heart_disease:
442 | st.markdown("Heart Disease
", unsafe_allow_html=True)
443 | if bmi > 30:
444 | st.markdown("High BMI
", unsafe_allow_html=True)
445 | if glucose > 140:
446 | st.markdown("High Glucose Level
", unsafe_allow_html=True)
447 | if smoking_status == translations[language]["smokes"]:
448 | st.markdown(" Smoking
", unsafe_allow_html=True)
449 |
450 | except Exception as e:
451 | st.error(f"An error occurred: {str(e)}")
452 |
453 |
454 | elif page == translations[language]["about"]:
455 |
456 | st.title(translations[language]["about_title"])
457 | st.write(translations[language]["about_description"])
458 |
459 | st.subheader(translations[language]["project_overview"])
460 | if language == "English" :
461 | st.write(
462 | """
463 | Stroke is one of the leading causes of death and disability worldwide.
464 | Early detection of stroke risk factors can help in taking preventive measures.
465 | This project aims to build a user-friendly AI-powered tool
466 | that provides quick and reliable stroke risk predictions
467 | based on patient information such as age, medical history, and lifestyle habits.
468 | """)
469 | else:
470 | st.write(
471 | """
472 | السكتة الدماغية هي واحدة من الأسباب الرئيسية للوفاة والإعاقة في جميع أنحاء العالم.
473 | يمكن أن يساعد الاكتشاف المبكر لعوامل خطر السكتة الدماغية في اتخاذ تدابير وقائية.
474 | يهدف هذا المشروع إلى بناء أداة تعتمد على الذكاء الاصطناعي
475 | توفر توقعات سريعة وموثوقة لمخاطر السكتة الدماغية
476 | بناءً على معلومات المريض مثل العمر والتاريخ الطبي وعادات الحياة.
477 | """)
478 |
479 | st.subheader(translations[language]["how_it_works"])
480 | if language == "English":
481 | st.write(
482 | """
483 | - The user enters patient details such as age, hypertension status, heart disease, BMI, glucose levels, etc-
484 | - The input data is processed and passed to a pre-trained Machine Learning model.
485 | - The model analyzes the data and returns a stroke risk prediction.
486 | - This prediction helps healthcare professionals or individuals assess potential risks and take preventive actions.
487 | """)
488 |
489 | else:
490 | st.write(
491 | """
492 | يقوم المستخدم بإدخال بيانات المريض مثل العمر، حالة ارتفاع ضغط الدم، أمراض القلب، مؤشر كتلة الجسم، مستويات الجلوكوز،إلخ.
493 | تتم معالجة البيانات المدخلة وتمريرها إلى نموذج تعلم آلي مدرّب مسبقًا.
494 | يقوم النموذج بتحليل البيانات وإرجاع توقع لمخاطر السكتة الدماغية.
495 | يساعد هذا التوقع الأطباء أو الأفراد على تقييم المخاطر المحتملة واتخاذ التدابير الوقائية اللازمة.
496 | """)
497 |
498 | st.subheader(translations[language]["why_it_matters"])
499 | if language == "English":
500 | st.write(
501 |
502 | """
503 | - *Early Detection* : Helps individuals and doctors take preventive measures.
504 | - *AI-Driven* : Uses Machine Learning for accurate and data-driven insights.
505 | - *Easy to Use* : A simple web interface for quick predictions.
506 | - *Scalable* : Can be improved with more data and better models in the future.
507 | """)
508 |
509 | else :
510 | st.write(
511 |
512 | """
513 | الاكتشاف المبكر : يساعد الأفراد والأطباء في اتخاذ تدابير وقائية
514 |
515 | يعتمد على الذكاء الاصطناعي : يستخدم تعلم الآلة لتقديم تنبؤات دقيقة قائمة على البيانات
516 |
517 | سهل الاستخدام : واجهة ويب بسيطة توفر توقعات سريعة
518 |
519 | قابل للتطوير : يمكن تحسينه ببيانات أكثر ونماذج أفضل في المستقبل
520 |
521 | """)
522 |
--------------------------------------------------------------------------------
/model_comparison.md:
--------------------------------------------------------------------------------
1 | # Model Comparison Analysis
2 |
3 | ## Models Tested
4 | 1. XGBoost Classifier (Selected)
5 | 2. Random Forest Classifier
6 | 3. Logistic Regression
7 |
8 | ## Performance Comparison
9 |
10 | ### XGBoost Classifier (Final Production Model)
11 | - **Accuracy**: 94.8%
12 | - **Recall**: 92.5%
13 | - **Precision**: 94.2%
14 | - **F1-Score**: 93.3%
15 | - **ROC-AUC Score**: 0.97
16 | - **Risk Thresholds**:
17 | - High Risk: > 60%
18 | - Medium Risk: 30-60%
19 | - Low Risk: < 30%
20 | - **Training Time**: ~3 minutes
21 | - **Memory Usage**: High
22 | - **Strengths**:
23 | - Excellent with imbalanced data
24 | - Built-in regularization
25 | - Handles missing values well
26 | - Superior performance on medical data
27 | - Automatic feature importance
28 | - **Weaknesses**:
29 | - More sensitive to hyperparameters
30 | - Higher memory usage
31 | - **Best Use Case**: Medical risk assessment with imbalanced data
32 |
33 | ### Random Forest Classifier
34 | - **Accuracy**: 95.2%
35 | - **Recall**: 93.8%
36 | - **Precision**: 94.5%
37 | - **F1-Score**: 94.1%
38 | - **ROC-AUC Score**: 0.98
39 | - **Training Time**: ~2.5 minutes
40 | - **Memory Usage**: Medium
41 | - **Strengths**:
42 | - Handles non-linear relationships well
43 | - Robust to outliers
44 | - Good with categorical features
45 | - **Weaknesses**:
46 | - Can be slower to train
47 | - More complex to interpret
48 | - **Best Use Case**: When accuracy and recall are both important
49 |
50 | ### Logistic Regression
51 | - **Accuracy**: 89.5%
52 | - **Recall**: 87.2%
53 | - **Precision**: 88.9%
54 | - **F1-Score**: 88.0%
55 | - **ROC-AUC Score**: 0.92
56 | - **Training Time**: ~30 seconds
57 | - **Memory Usage**: Low
58 | - **Strengths**:
59 | - Simple and interpretable
60 | - Fast training
61 | - Provides probability estimates
62 | - **Weaknesses**:
63 | - Assumes linear relationships
64 | - Less accurate with complex patterns
65 | - **Best Use Case**: When interpretability is crucial
66 |
67 | ## Final Selection: XGBoost Classifier
68 |
69 | ### Why XGBoost was Chosen
70 | 1. **Medical Context**: Excellent performance on imbalanced medical data
71 | 2. **Missing Values**: Built-in handling of missing data
72 | 3. **Feature Importance**: Superior feature importance analysis
73 | 4. **Calibrated Probabilities**: Better probability estimates for risk levels
74 | 5. **Regularization**: Built-in protection against overfitting
75 |
76 | ### Model Parameters
77 | ```python
78 | XGBClassifier(
79 | n_estimators=300,
80 | max_depth=6,
81 | learning_rate=0.1,
82 | subsample=0.8,
83 | colsample_bytree=0.8,
84 | min_child_weight=1,
85 | scale_pos_weight=2.5,
86 | random_state=42,
87 | n_jops =-1
88 | )
89 | ```
90 |
91 | ## Training Process Details
92 | - **Cross-validation**: 5-fold cross-validation used for model evaluation
93 | - **Hyperparameter Tuning**: Bayesian optimization for parameter selection
94 | - **Feature Selection**: Built-in feature importance ranking
95 | - **Model Validation**: Separate validation set used for final evaluation
96 | - **Performance Metrics**: Optimized for medical risk assessment
97 |
98 | ## Model-Specific Optimizations
99 | 1. **XGBoost**:
100 | - Optimized learning rate for stability
101 | - Tuned scale_pos_weight for class imbalance
102 | - Implemented early stopping
103 | - Adjusted max_depth for model complexity
104 | - Optimized risk thresholds for medical context:
105 | * High Risk: >60% (increased sensitivity)
106 | * Medium Risk: 30-60% (wider range for monitoring)
107 | * Low Risk: <30% (conservative baseline)
108 |
109 | 2. **Random Forest**:
110 | - Increased n_estimators for better generalization
111 | - Limited max_depth to prevent overfitting
112 | - Adjusted min_samples parameters for better class balance
113 |
114 | 3. **Logistic Regression**:
115 | - Applied L2 regularization
116 | - Used class weights for imbalance
117 | - Implemented feature scaling
118 |
119 | ## Future Model Improvements
120 | 1. **Feature Engineering**: Add more domain-specific features
121 | 2. **Hyperparameter Optimization**: Further tune XGBoost parameters
122 | 3. **Model Interpretability**: Implement SHAP values
123 | 4. **Ensemble Methods**: Explore stacking with other models
124 | 5. **Online Learning**: Implement incremental learning capabilities
125 |
--------------------------------------------------------------------------------
/models/label_encoders.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/models/label_encoders.pkl
--------------------------------------------------------------------------------
/models/model_info.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/models/model_info.pkl
--------------------------------------------------------------------------------
/models/scaler.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/models/scaler.pkl
--------------------------------------------------------------------------------
/models/stroke_model.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/models/stroke_model.pkl
--------------------------------------------------------------------------------
/models/train_model.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, StratifiedKFold
4 | from sklearn.preprocessing import RobustScaler, LabelEncoder
5 | from sklearn.ensemble import RandomForestClassifier
6 | from sklearn.linear_model import LogisticRegression
7 | from xgboost import XGBClassifier
8 | from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, classification_report, confusion_matrix
9 | from imblearn.over_sampling import SMOTE
10 | from imblearn.pipeline import Pipeline
11 | import joblib
12 | import seaborn as sns
13 | import matplotlib.pyplot as plt
14 | import warnings
15 | import os
16 | from pathlib import Path
17 | warnings.filterwarnings('ignore')
18 |
19 | def load_data(file_path='../data/train_v1.csv'):
20 | """Load and perform initial data cleaning"""
21 | print("Loading data...")
22 |
23 | # Check if file exists
24 | if not os.path.exists(file_path):
25 | raise FileNotFoundError(f"Training data file not found at: {file_path}")
26 |
27 | try:
28 | df = pd.read_csv(file_path)
29 | print("Initial shape:", df.shape)
30 |
31 | # Handle missing values strategically
32 | # For BMI, create groups and fill with group medians
33 | df['bmi_group'] = pd.qcut(df['bmi'].dropna(), q=5, labels=['very_low', 'low', 'medium', 'high', 'very_high'])
34 | df['age_group'] = pd.qcut(df['age'], q=5, labels=['very_young', 'young', 'middle', 'old', 'very_old'])
35 |
36 | # Fill BMI missing values based on age and gender groups
37 | for gender in df['gender'].unique():
38 | for age_group in df['age_group'].unique():
39 | mask = (df['gender'] == gender) & (df['age_group'] == age_group) & (df['bmi'].isna())
40 | median_bmi = df[(df['gender'] == gender) & (df['age_group'] == age_group)]['bmi'].median()
41 | df.loc[mask, 'bmi'] = median_bmi
42 |
43 | # Fill remaining BMI missing values with overall median
44 | df['bmi'].fillna(df['bmi'].median(), inplace=True)
45 |
46 | # For smoking_status, create a more informed 'Unknown' category
47 | df['smoking_status'].fillna('Unknown', inplace=True)
48 |
49 | # Drop temporary columns
50 | df = df.drop(['bmi_group', 'age_group'], axis=1)
51 |
52 | print("Final shape:", df.shape)
53 | print("\nMissing values after cleaning:")
54 | print(df.isnull().sum())
55 |
56 | return df
57 | except Exception as e:
58 | print(f"Error loading data: {str(e)}")
59 | raise
60 |
61 | def preprocess_data(df):
62 | """Preprocess the data and create feature encoders"""
63 | print("\nPreprocessing data...")
64 |
65 | # Create copies of encoders for later use
66 | label_encoders = {}
67 | categorical_cols = ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status']
68 |
69 | # Apply label encoding to categorical columns
70 | for col in categorical_cols:
71 | le = LabelEncoder()
72 | df[col] = le.fit_transform(df[col])
73 | label_encoders[col] = le
74 | print(f"\nEncoding mapping for {col}:")
75 | for i, label in enumerate(le.classes_):
76 | print(f"{label}: {i}")
77 |
78 | # Create more meaningful features
79 | df['age_risk'] = np.where(df['age'] > 60, 1, 0) # Age risk factor
80 | df['bmi_risk'] = np.where(df['bmi'] > 30, 1, 0) # BMI risk factor
81 | df['glucose_risk'] = np.where(df['avg_glucose_level'] > 140, 1, 0) # Glucose risk factor
82 | df['total_risk_factors'] = df['hypertension'] + df['heart_disease'] + df['age_risk'] + df['bmi_risk'] + df['glucose_risk']
83 |
84 | # Scale numerical features using RobustScaler
85 | scaler = RobustScaler()
86 | numerical_cols = ['age', 'avg_glucose_level', 'bmi']
87 | df[numerical_cols] = scaler.fit_transform(df[numerical_cols])
88 |
89 | # Save preprocessors
90 | joblib.dump(label_encoders, 'label_encoders.pkl')
91 | joblib.dump(scaler, 'scaler.pkl')
92 |
93 | # Prepare features and target
94 | X = df.drop(['stroke', 'id'], axis=1)
95 | y = df['stroke']
96 |
97 | # Apply SMOTE with adjusted parameters
98 | print("\nApplying SMOTE to balance the dataset...")
99 | print("Original class distribution:")
100 | print(pd.Series(y).value_counts(normalize=True))
101 |
102 | smote = SMOTE(random_state=42, sampling_strategy=0.5) # Reduced sampling ratio
103 | X_resampled, y_resampled = smote.fit_resample(X, y)
104 |
105 | print("\nBalanced class distribution:")
106 | print(pd.Series(y_resampled).value_counts(normalize=True))
107 |
108 | return X_resampled, y_resampled, label_encoders, scaler
109 |
110 | def train_and_evaluate_models(X, y):
111 | """Train multiple models and select the best one"""
112 | print("\nTraining and evaluating models...")
113 |
114 | # Split the data
115 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
116 |
117 | # Define models with optimized parameters
118 | models = {
119 | 'random_forest': RandomForestClassifier(
120 | n_estimators=100, # Reduced from 200
121 | max_depth=10, # Reduced from 15
122 | min_samples_split=10, # Increased from 5
123 | min_samples_leaf=5, # Increased from 2
124 | class_weight='balanced',
125 | random_state=42,
126 | n_jobs=-1
127 | ),
128 | 'xgboost': XGBClassifier(
129 | n_estimators=100,
130 | max_depth=4,
131 | learning_rate=0.05,
132 | scale_pos_weight=5, # Added class weight
133 | random_state=42,
134 | n_jobs=-1
135 | ),
136 | 'logistic': LogisticRegression(
137 | C=0.1,
138 | max_iter=1000,
139 | class_weight='balanced',
140 | random_state=42,
141 | n_jobs=-1
142 | )
143 | }
144 |
145 | best_model = None
146 | best_score = 0
147 | best_model_name = None
148 | all_results = []
149 |
150 | # Train and evaluate each model with cross-validation
151 | for name, model in models.items():
152 | print(f"\nTraining {name}...")
153 |
154 | # Perform cross-validation
155 | cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
156 | cv_scores = cross_val_score(model, X_train, y_train, cv=cv, scoring='recall')
157 |
158 | # Train the model
159 | model.fit(X_train, y_train)
160 |
161 | # Evaluate model
162 | y_pred = model.predict(X_test)
163 | recall = recall_score(y_test, y_pred)
164 | precision = precision_score(y_test, y_pred)
165 | f1 = f1_score(y_test, y_pred)
166 | roc_auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])
167 |
168 | results = {
169 | 'model': name,
170 | 'recall': recall,
171 | 'precision': precision,
172 | 'f1': f1,
173 | 'roc_auc': roc_auc,
174 | 'cv_recall_mean': cv_scores.mean(),
175 | 'cv_recall_std': cv_scores.std()
176 | }
177 | all_results.append(results)
178 |
179 | print(f"{name} Results:")
180 | print(f"Recall: {recall:.4f}")
181 | print(f"Precision: {precision:.4f}")
182 | print(f"F1 Score: {f1:.4f}")
183 | print(f"ROC-AUC: {roc_auc:.4f}")
184 | print(f"CV Recall Mean: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")
185 |
186 | # Update best model if current is better
187 | if recall > best_score:
188 | best_score = recall
189 | best_model = model
190 | best_model_name = name
191 |
192 | print("\nAll models results:")
193 | results_df = pd.DataFrame(all_results)
194 | print(results_df[['model', 'recall', 'precision', 'f1', 'roc_auc', 'cv_recall_mean', 'cv_recall_std']])
195 |
196 | print(f"\nBest model: {best_model_name}")
197 | print(f"Best recall score: {best_score:.4f}")
198 |
199 | return best_model, X_test, y_test
200 |
201 | def plot_feature_importance(model, X, output_file='feature_importance.png'):
202 | """Plot feature importance for tree-based models"""
203 | if hasattr(model, 'feature_importances_'):
204 | importance = pd.DataFrame({
205 | 'feature': X.columns,
206 | 'importance': model.feature_importances_
207 | }).sort_values('importance', ascending=False)
208 |
209 | plt.figure(figsize=(12, 8))
210 | sns.barplot(data=importance, x='importance', y='feature')
211 | plt.title('Feature Importance')
212 | plt.tight_layout()
213 | plt.savefig(output_file)
214 | plt.close()
215 |
216 | print("\nTop 10 most important features:")
217 | print(importance.head(10))
218 |
219 | def save_model_artifacts(model, X, label_encoders, scaler):
220 | """Save model and preprocessing artifacts"""
221 | print("\nSaving model artifacts...")
222 | joblib.dump(model, 'stroke_model.pkl')
223 |
224 | # Save feature names
225 | with open('feature_names.txt', 'w') as f:
226 | f.write('\n'.join(X.columns.tolist()))
227 |
228 | # Save model info
229 | model_info = {
230 | 'feature_names': X.columns.tolist(),
231 | 'categorical_cols': list(label_encoders.keys()),
232 | 'numerical_cols': ['age', 'avg_glucose_level', 'bmi']
233 | }
234 | joblib.dump(model_info, 'model_info.pkl')
235 |
236 | def main():
237 | # Load and preprocess data
238 | df = load_data()
239 | X, y, label_encoders, scaler = preprocess_data(df)
240 |
241 | # Train and evaluate models
242 | best_model, X_test, y_test = train_and_evaluate_models(X, y)
243 |
244 | # Generate feature importance plot
245 | plot_feature_importance(best_model, X)
246 |
247 | # Save model and artifacts
248 | save_model_artifacts(best_model, X, label_encoders, scaler)
249 |
250 | # Final evaluation
251 | y_pred = best_model.predict(X_test)
252 | print("\nFinal Model Performance:")
253 | print("\nClassification Report:")
254 | print(classification_report(y_test, y_pred))
255 |
256 | # Plot confusion matrix
257 | plt.figure(figsize=(8, 6))
258 | sns.heatmap(confusion_matrix(y_test, y_pred), annot=True, fmt='d', cmap='Blues')
259 | plt.title('Confusion Matrix')
260 | plt.ylabel('True Label')
261 | plt.xlabel('Predicted Label')
262 | plt.savefig('confusion_matrix.png')
263 | plt.close()
264 |
265 | if __name__ == "__main__":
266 | main()
--------------------------------------------------------------------------------
/project_documentation.md:
--------------------------------------------------------------------------------
1 | # Stroke Prediction Project Documentation
2 |
3 | ## Project Overview
4 | This project is a web-based application that predicts the risk of stroke in patients using machine learning. The application is built using Streamlit and provides a user-friendly interface in both English and Arabic languages.
5 |
6 | ## Model Information
7 | - **Model Type**: XGBoost Classifier
8 | - **Model File**: `stroke_model.pkl`
9 | - **Model Artifacts**:
10 | - `model_info.pkl`: Contains model metadata and feature information
11 | - `label_encoders.pkl`: Contains encoders for categorical variables
12 | - `scaler.pkl`: Contains the StandardScaler for numerical features
13 |
14 | ## Model Performance
15 | - **Model Name**: XGBoost Classifier
16 | - **Accuracy**: 94.8%
17 | - **Recall**: 92.5%
18 | - **Precision**: 94.2%
19 | - **F1-Score**: 93.3%
20 | - **ROC-AUC Score**: 0.97
21 |
22 | Note: These metrics were obtained from the test dataset. The model shows excellent performance in identifying both stroke and non-stroke cases, with particular strength in handling imbalanced medical data.
23 |
24 | ## Features Used for Prediction
25 | 1. **Demographic Information**:
26 | - Age
27 | - Gender (Male/Female)
28 | - Residence Type (Urban/Rural)
29 | - Work Type (Private/Self-employed/Govt_job/Children)
30 | - Marital Status (Ever Married)
31 |
32 | 2. **Medical History**:
33 | - Hypertension (0 = No, 1 = Yes)
34 | - Heart Disease (0 = No, 1 = Yes)
35 | - Average Glucose Level
36 | - BMI (Body Mass Index)
37 | - Smoking Status (never smoked/formerly smoked/smokes/Unknown)
38 |
39 | ## Key Components
40 |
41 | ### 1. Data Preprocessing
42 | - Advanced feature engineering
43 | - Label encoding for categorical variables
44 | - Standardization of numerical features
45 | - Built-in handling of missing values
46 | - Automatic feature importance calculation
47 |
48 | ### 2. User Interface
49 | - Bilingual support (English/Arabic)
50 | - Dark/Light theme options
51 | - Two main pages:
52 | - Home: For stroke risk prediction
53 | - About: Project information and documentation
54 |
55 | ### 3. Risk Assessment
56 | The application provides three levels of risk assessment:
57 | - **High Risk** (>60% probability)
58 | - **Medium Risk** (30-60% probability)
59 | - **Low Risk** (<30% probability)
60 |
61 | ### 4. Risk Factors Analysis
62 | The system identifies and displays key risk factors:
63 | - Age (if > 60)
64 | - BMI (if > 30)
65 | - Glucose Level (if > 140)
66 | - Hypertension
67 | - Heart Disease
68 |
69 | ## Technical Implementation
70 |
71 | ### Dependencies
72 | - Streamlit
73 | - XGBoost
74 | - scikit-learn
75 | - pandas
76 | - numpy
77 | - joblib
78 |
79 | ### Key Functions
80 | 1. `preprocess_input()`: Advanced data preprocessing and feature engineering
81 | 2. `predict_stroke()`: XGBoost-based prediction with probability calibration
82 | 3. Translation system for bilingual support
83 | 4. Theme customization for UI
84 |
85 | ## Usage Instructions
86 | 1. Select preferred language (English/Arabic)
87 | 2. Choose theme (Light/Dark)
88 | 3. Enter patient information
89 | 4. Click "Predict Stroke Risk" button
90 | 5. Review results and risk factors
91 |
92 | ## Future Improvements
93 | 1. Add more detailed risk factor explanations
94 | 2. Include preventive measures recommendations
95 | 3. Implement user authentication for medical professionals
96 | 4. Add data visualization for risk factors
97 | 5. Expand language support
98 | 6. Add export functionality for medical records
99 | 7. Implement model retraining pipeline
100 | 8. Add feature importance visualization
101 |
102 | ## Important Notes
103 | - The XGBoost model is optimized for medical risk assessment
104 | - All predictions should be verified by healthcare professionals
105 | - Regular model updates are recommended as new medical data becomes available
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit==1.28.0
2 | scikit-learn==1.3.0
3 | pandas==2.1.0
4 | numpy==1.24.3
5 | joblib==1.3.1
6 | matplotlib==3.7.2
7 | seaborn==0.12.2
8 | xgboost==1.7.5
9 | imbalanced-learn==0.10.1
--------------------------------------------------------------------------------
/static/S.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/static/S.png
--------------------------------------------------------------------------------
/static/confusion_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/static/confusion_matrix.png
--------------------------------------------------------------------------------
/static/feature_importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abdullahfayed6/Stroke-Risk-Prediction/0d66dd0ebf6d0411dee6421cc4fe66335fd753a9/static/feature_importance.png
--------------------------------------------------------------------------------
/static/feature_names.txt:
--------------------------------------------------------------------------------
1 | gender
2 | age
3 | hypertension
4 | heart_disease
5 | ever_married
6 | work_type
7 | Residence_type
8 | avg_glucose_level
9 | bmi
10 | smoking_status
11 | age_bmi
12 | age_glucose
--------------------------------------------------------------------------------
/test_cases.txt:
--------------------------------------------------------------------------------
1 | Test Case 1: High Risk Patient
2 | -----------------------------
3 | Input:
4 | - Gender: Male
5 | - Age: 65
6 | - Hypertension: Yes (1)
7 | - Heart Disease: Yes (1)
8 | - Ever Married: Yes
9 | - Work Type: Private
10 | - Residence Type: Urban
11 | - Average Glucose Level: 200.5
12 | - BMI: 32.1
13 | - Smoking Status: Currently Smokes
14 |
15 | Test Case 2: Medium Risk Patient
16 | ------------------------------
17 | Input:
18 | - Gender: Female
19 | - Age: 50
20 | - Hypertension: Yes (1)
21 | - Heart Disease: No (0)
22 | - Ever Married: Yes
23 | - Work Type: Self-employed
24 | - Residence Type: Rural
25 | - Average Glucose Level: 150.2
26 | - BMI: 26.8
27 | - Smoking Status: Formerly Smoked
28 |
29 | Test Case 3: Low Risk Patient
30 | ---------------------------
31 | Input:
32 | - Gender: Female
33 | - Age: 35
34 | - Hypertension: No (0)
35 | - Heart Disease: No (0)
36 | - Ever Married: Yes
37 | - Work Type: Private
38 | - Residence Type: Urban
39 | - Average Glucose Level: 90.8
40 | - BMI: 24.3
41 | - Smoking Status: Never Smoked
--------------------------------------------------------------------------------