├── dataset.xlsx
├── roc_curve.png
├── confusion_matrix.png
├── feature_importance.png
├── metadata.json
├── report.txt
├── README.md
└── eeg_stress_classification.py

/dataset.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Okes2024/EEG-Signal-Classification-for-Stress-Level-Detection/HEAD/dataset.xlsx
--------------------------------------------------------------------------------
/roc_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Okes2024/EEG-Signal-Classification-for-Stress-Level-Detection/HEAD/roc_curve.png
--------------------------------------------------------------------------------
/confusion_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Okes2024/EEG-Signal-Classification-for-Stress-Level-Detection/HEAD/confusion_matrix.png
--------------------------------------------------------------------------------
/feature_importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Okes2024/EEG-Signal-Classification-for-Stress-Level-Detection/HEAD/feature_importance.png
--------------------------------------------------------------------------------
/metadata.json:
--------------------------------------------------------------------------------
1 | {
2 |   "cv_accuracy_mean": 1.0,
3 |   "cv_accuracy_std": 0.0,
4 |   "roc_auc": 1.0,
5 |   "n_train": 450,
6 |   "n_test": 150,
7 |   "n_features": 28
8 | }
--------------------------------------------------------------------------------
/report.txt:
--------------------------------------------------------------------------------
1 | EEG Stress Classification (Synthetic)
2 | ========================================
3 | CV Accuracy (mean ± std): 1.000 ± 0.000
4 |
5 | Classification Report (Test):
6 |               precision    recall  f1-score   support
7 |
8 |         Calm       1.00      1.00      1.00        75
9 |       Stress       1.00      1.00      1.00        75
10 |
11 |     accuracy                           1.00       150
12 |    macro avg       1.00      1.00      1.00       150
13 | weighted avg       1.00      1.00      1.00       150
14 |
15 | Confusion Matrix:
16 | [[75  0]
17 |  [ 0 75]]
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | EEG-Signal-Classification-for-Stress-Level-Detection
2 |
3 | 🧠 EEG Signal Classification using synthetic data to detect stress levels. This project simulates multi-channel EEG signals, extracts frequency-band features, and trains a machine learning model to classify mental states as Calm or Stress.
4 |
5 | 🚀 Project Overview
6 |
7 | Stress detection from EEG signals is a growing field in neurotechnology and mental health monitoring.
8 | Since real EEG datasets can be difficult to obtain, this project uses synthetic EEG data to:
9 |
10 | Generate >100 EEG windows with realistic oscillatory patterns.
11 |
12 | Extract band power features (delta, theta, alpha, beta, gamma); a minimal sketch follows this list.
13 |
14 | Compute stress markers (e.g., beta/alpha ratio).
15 |
16 | Train a Random Forest classifier with cross-validation.
17 |
18 | Evaluate performance using confusion matrix, ROC curve, and feature importance plots.
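The sketch below is an editorial illustration rather than a file from this repository. It mirrors the band limits and Welch settings used in eeg_stress_classification.py, while the single-channel toy signal, the variable names, and the printed summary are invented for demonstration.

```python
# Illustrative sketch only (not a repository file). Band limits and Welch
# settings mirror eeg_stress_classification.py; the toy signal is invented.
import numpy as np
from scipy.signal import welch

fs = 128                        # sampling rate in Hz, the script's default
t = np.arange(2 * fs) / fs      # one 2-second window
rng = np.random.default_rng(0)

# Toy "calm-like" signal: strong 10 Hz alpha, weaker 18 Hz beta, plus noise
sig = (np.sin(2 * np.pi * 10 * t)
       + 0.5 * np.sin(2 * np.pi * 18 * t)
       + 0.3 * rng.standard_normal(t.size))

freqs, psd = welch(sig, fs=fs, nperseg=min(256, sig.size))
bands = {"delta": (0.5, 4), "theta": (4, 8), "alpha": (8, 13),
         "beta": (13, 30), "gamma": (30, 45)}

power = {}
for name, (lo, hi) in bands.items():
    mask = (freqs >= lo) & (freqs < hi)
    power[name] = np.trapz(psd[mask], freqs[mask])   # integrate the PSD over the band

print({name: round(val, 4) for name, val in power.items()})
print("beta/alpha ratio:", power["beta"] / (power["alpha"] + 1e-8))
```

For a calm-like window such as this one the beta/alpha ratio should come out well below 1; the stress simulation in the script raises beta amplitude and lowers alpha, which pushes the ratio up.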
19 |
20 | 📂 Repository Structure
21 | ├── eeg_stress_classification.py   # Main script
22 | ├── outputs/                       # Generated results
23 | │   ├── dataset.csv                # Feature dataset
24 | │   ├── report.txt                 # Classification report
25 | │   ├── confusion_matrix.png       # Confusion matrix (test set)
26 | │   ├── roc_curve.png              # ROC curve (test set)
27 | │   ├── feature_importance.png     # Top feature importance plot
28 | │   └── metadata.json              # Performance metadata
29 | └── README.md                      # Project documentation
30 |
31 | 🔑 Features
32 |
33 | Synthetic EEG generation (calm vs stress).
34 |
35 | Band power extraction via Welch periodogram.
36 |
37 | Ratios for stress biomarkers (β/α, θ/α).
38 |
39 | Balanced dataset (>100 samples).
40 |
41 | Train/test split with 5-fold cross-validation.
42 |
43 | Export of dataset and evaluation plots.
44 |
45 | ⚙️ Installation
46 |
47 | Clone the repository and install dependencies:
48 |
49 | git clone https://github.com/Okes2024/EEG-Signal-Classification-for-Stress-Level-Detection.git
50 | cd EEG-Signal-Classification-for-Stress-Level-Detection
51 | pip install numpy scipy pandas scikit-learn matplotlib
52 |
53 | 🧑‍💻 Usage
54 |
55 | Run the main script with default parameters (600 samples, 4 channels, 128 Hz, 2 s windows):
56 |
57 | python eeg_stress_classification.py
58 |
59 |
60 | Customize parameters:
61 |
62 | python eeg_stress_classification.py --samples 1000 --channels 8 --fs 256 --duration 3.0 --outdir results
63 |
64 | 📊 Results
65 |
66 | Generated outputs include the following; a short loading sketch appears at the end of this README:
67 |
68 | dataset.csv → feature dataset with labels.
69 |
70 | report.txt → accuracy, precision, recall, F1-score.
71 |
72 | Confusion Matrix → classification performance.
73 |
74 | ROC Curve → AUC evaluation.
75 |
76 | Feature Importances → top EEG band features for stress detection.
77 |
78 | 🌍 Applications
79 |
80 | Cognitive load monitoring.
81 |
82 | Stress and fatigue detection.
83 |
84 | Brain-computer interfaces (BCI).
85 |
86 | Prototyping mental health monitoring tools.
87 |
88 | 🤝 Contributing
89 |
90 | Contributions are welcome! Please fork this repo, create a feature branch, and submit a pull request.
91 |
92 | 📜 License
93 |
94 | This project is licensed under the MIT License. See LICENSE
95 | for details.
96 |
97 | 📧 Contact
98 |
99 | Maintainer: Imoni Okes
100 | GitHub: https://github.com/Okes2024
101 |
102 | 🔗 LinkedIn
103 |
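As mentioned in the Results section, the following is a minimal sketch (again an editorial illustration, not a repository file) for inspecting the artifacts of a default run. It assumes the script has already been executed with the default --outdir outputs and that the file names match those written by train_and_evaluate.

```python
# Illustrative sketch only: load the artifacts written by a default run of
# eeg_stress_classification.py. Paths assume --outdir outputs was used.
import json
from pathlib import Path

import pandas as pd

outdir = Path("outputs")

meta = json.loads((outdir / "metadata.json").read_text())
print("CV accuracy: {:.3f} ± {:.3f}".format(meta["cv_accuracy_mean"], meta["cv_accuracy_std"]))
print("ROC AUC:", meta["roc_auc"])

df = pd.read_csv(outdir / "dataset.csv")   # scaled band-power features plus a label column
print(df["label"].value_counts())          # should be balanced between Calm (0) and Stress (1)
print(df.filter(like="beta_alpha_ratio").describe())
```

If a custom --outdir was passed to the script, point outdir at that folder instead.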
--------------------------------------------------------------------------------
/eeg_stress_classification.py:
--------------------------------------------------------------------------------
1 | 2 | """ 3 | EEG-Signal-Classification-for-Stress-Level-Detection (Synthetic Data) 4 | -------------------------------------------------------------------- 5 | Generates synthetic multi-channel EEG windows for two classes: 6 | - 0 = Calm 7 | - 1 = Stress 8 | 9 | Workflow 10 | - Simulate EEG-like signals (alpha/beta differences) for each window 11 | - Extract band-power features via Welch periodogram 12 | - Train/test split, cross-val, RandomForest classifier 13 | - Save metrics, confusion matrix, ROC-AUC plots, and feature importances 14 | - Export the feature dataset to CSV 15 | 16 | Run 17 | python eeg_stress_classification.py --samples 600 --channels 4 --fs 128 --duration 2.0 18 | 19 | Outputs (created under ./outputs/) 20 | - dataset.csv 21 | - confusion_matrix.png 22 | - roc_curve.png 23 | - feature_importance.png 24 | - report.txt 25 | 26 | Author: Your Name 27 | License: MIT 28 | """ 29 | import argparse 30 | import json 31 | from pathlib import Path 32 | import numpy as np 33 | import pandas as pd 34 | from scipy.signal import welch, butter, filtfilt 35 | from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score 36 | from sklearn.ensemble import RandomForestClassifier 37 | from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc 38 | import matplotlib.pyplot as plt 39 | from sklearn.preprocessing import StandardScaler 40 | 41 | # ---------------------- Signal Utilities ---------------------- 42 | def band_lims(): 43 | # (low, high) in Hz 44 | return { 45 | "delta": (0.5, 4), 46 | "theta": (4, 8), 47 | "alpha": (8, 13), 48 | "beta": (13, 30), 49 | "gamma": (30, 45) 50 | } 51 | 52 | def bandpower_from_psd(freqs, psd, fmin, fmax): 53 | mask = (freqs >= fmin) & (freqs < fmax) 54 | return np.trapz(psd[mask], freqs[mask]) 55 | 56 | def simulate_eeg_window(fs, duration, channels, label, rng): 57 | """ 58 | Create multi-channel EEG-like window. 59 | Calm: stronger alpha, moderate theta, lower beta.
60 | Stress: reduced alpha, elevated beta, slight gamma increase. 61 | """ 62 | n = int(fs * duration) 63 | t = np.arange(n) / fs 64 | 65 | # Base pink-ish noise (1/f) 66 | def pink_noise(n, rng): 67 | # Voss-McCartney style approx using cumulative sums of white noise at powers of two 68 | num_layers = 16 69 | white = rng.standard_normal((num_layers, n)) 70 | cum = np.cumsum(white, axis=1) 71 | weights = 2.0 ** (-np.arange(num_layers)) 72 | pink = (weights[:, None] * cum).sum(axis=0) 73 | pink = pink / np.std(pink) 74 | return pink 75 | 76 | X = np.zeros((channels, n)) 77 | for ch in range(channels): 78 | sig = pink_noise(n, rng) * 0.3 79 | 80 | # Add oscillatory components 81 | # Alpha component stronger for calm, weaker for stress 82 | alpha_amp = rng.uniform(1.2, 1.8) if label == 0 else rng.uniform(0.4, 0.9) 83 | alpha_freq = rng.uniform(9, 12) 84 | sig += alpha_amp * np.sin(2*np.pi*alpha_freq*t + rng.uniform(0, 2*np.pi)) 85 | 86 | # Beta component stronger for stress 87 | beta_amp = rng.uniform(0.5, 1.0) if label == 0 else rng.uniform(1.2, 2.0) 88 | beta_freq = rng.uniform(15, 22) 89 | sig += beta_amp * np.sin(2*np.pi*beta_freq*t + rng.uniform(0, 2*np.pi)) 90 | 91 | # Theta moderate for calm, similar for stress 92 | theta_amp = rng.uniform(0.6, 1.0) if label == 0 else rng.uniform(0.5, 0.9) 93 | theta_freq = rng.uniform(5, 7.5) 94 | sig += theta_amp * np.sin(2*np.pi*theta_freq*t + rng.uniform(0, 2*np.pi)) 95 | 96 | # Small gamma spurts for stress 97 | if label == 1: 98 | gamma_amp = rng.uniform(0.2, 0.6) 99 | gamma_freq = rng.uniform(32, 40) 100 | sig += gamma_amp * np.sin(2*np.pi*gamma_freq*t + rng.uniform(0, 2*np.pi)) 101 | 102 | # Mild bandpass 0.5–45 Hz to mimic EEG preprocessing 103 | b, a = butter(4, [0.5/(fs/2), 45/(fs/2)], btype='band') 104 | sig = filtfilt(b, a, sig) 105 | 106 | # Normalize per channel 107 | sig = (sig - np.mean(sig)) / (np.std(sig) + 1e-8) 108 | 109 | X[ch, :] = sig 110 | 111 | return X # shape (channels, n) 112 | 113 | def extract_bandpowers(window, fs): 114 | """Compute Welch PSD for each channel and integrate band powers.""" 115 | bands = band_lims() 116 | ch, n = window.shape 117 | feats = {} 118 | for c in range(ch): 119 | freqs, psd = welch(window[c, :], fs=fs, nperseg=min(256, n)) 120 | for bname, (lo, hi) in bands.items(): 121 | feats[f"ch{c+1}_{bname}"] = bandpower_from_psd(freqs, psd, lo, hi) 122 | # Ratios often used as stress markers 123 | feats[f"ch{c+1}_beta_alpha_ratio"] = feats[f"ch{c+1}_beta"] / (feats[f"ch{c+1}_alpha"] + 1e-8) 124 | feats[f"ch{c+1}_theta_alpha_ratio"] = feats[f"ch{c+1}_theta"] / (feats[f"ch{c+1}_alpha"] + 1e-8) 125 | return feats 126 | 127 | # ---------------------- Main Pipeline ---------------------- 128 | def generate_dataset(n_samples=600, channels=4, fs=128, duration=2.0, seed=42): 129 | rng = np.random.default_rng(seed) 130 | X_feats = [] 131 | y = [] 132 | for i in range(n_samples): 133 | label = 0 if i < (n_samples // 2) else 1 # balance 134 | window = simulate_eeg_window(fs, duration, channels, label, rng) 135 | feats = extract_bandpowers(window, fs) 136 | X_feats.append(feats) 137 | y.append(label) 138 | df = pd.DataFrame(X_feats) 139 | df["label"] = y 140 | return df 141 | 142 | def train_and_evaluate(df, outdir: Path, seed=42): 143 | outdir.mkdir(parents=True, exist_ok=True) 144 | 145 | X = df.drop(columns=["label"]).values 146 | y = df["label"].values 147 | 148 | scaler = StandardScaler() 149 | X_scaled = scaler.fit_transform(X) 150 | 151 | X_train, X_test, y_train, y_test = train_test_split( 152 | X_scaled, y, 
test_size=0.25, stratify=y, random_state=seed 153 | ) 154 | 155 | clf = RandomForestClassifier( 156 | n_estimators=300, 157 | max_depth=None, 158 | random_state=seed, 159 | n_jobs=-1 160 | ) 161 | # Cross-validation on training set 162 | cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed) 163 | cv_scores = cross_val_score(clf, X_train, y_train, cv=cv, scoring="accuracy") 164 | 165 | clf.fit(X_train, y_train) 166 | y_pred = clf.predict(X_test) 167 | y_proba = clf.predict_proba(X_test)[:, 1] 168 | 169 | # Reports 170 | report = classification_report(y_test, y_pred, target_names=["Calm", "Stress"]) 171 | cm = confusion_matrix(y_test, y_pred) 172 | 173 | # Save report 174 | report_txt = outdir / "report.txt" 175 | with report_txt.open("w") as f: 176 | f.write("EEG Stress Classification (Synthetic)\n") 177 | f.write("="*40 + "\n") 178 | f.write(f"CV Accuracy (mean ± std): {cv_scores.mean():.3f} ± {cv_scores.std():.3f}\n\n") 179 | f.write("Classification Report (Test):\n") 180 | f.write(report + "\n") 181 | f.write("Confusion Matrix:\n") 182 | f.write(np.array2string(cm) + "\n") 183 | 184 | # Confusion Matrix plot (single figure) 185 | plt.figure() 186 | im = plt.imshow(cm, interpolation="nearest") 187 | plt.title("Confusion Matrix (Test)") 188 | plt.xticks([0,1], ["Calm", "Stress"]) 189 | plt.yticks([0,1], ["Calm", "Stress"]) 190 | for (i, j), v in np.ndenumerate(cm): 191 | plt.text(j, i, str(v), ha="center", va="center") 192 | plt.xlabel("Predicted") 193 | plt.ylabel("True") 194 | plt.colorbar(im, fraction=0.046, pad=0.04) 195 | plt.tight_layout() 196 | plt.savefig(outdir / "confusion_matrix.png", dpi=200) 197 | plt.close() 198 | 199 | # ROC curve (single figure) 200 | fpr, tpr, _ = roc_curve(y_test, y_proba) 201 | roc_auc = auc(fpr, tpr) 202 | plt.figure() 203 | plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.3f}") 204 | plt.plot([0,1], [0,1], linestyle="--") 205 | plt.xlabel("False Positive Rate") 206 | plt.ylabel("True Positive Rate") 207 | plt.title("ROC Curve (Test)") 208 | plt.legend(loc="lower right") 209 | plt.tight_layout() 210 | plt.savefig(outdir / "roc_curve.png", dpi=200) 211 | plt.close() 212 | 213 | # Feature importances (single figure) 214 | importances = clf.feature_importances_ 215 | indices = np.argsort(importances)[::-1][:20] # top 20 216 | labels = [f for f in df.drop(columns=["label"]).columns[indices]] 217 | plt.figure(figsize=(8, 6)) 218 | plt.bar(range(len(indices)), importances[indices]) 219 | plt.xticks(range(len(indices)), labels, rotation=90) 220 | plt.ylabel("Importance") 221 | plt.title("Top Feature Importances") 222 | plt.tight_layout() 223 | plt.savefig(outdir / "feature_importance.png", dpi=200) 224 | plt.close() 225 | 226 | # Save scaled dataset to CSV 227 | cols = list(df.drop(columns=["label"]).columns) + ["label"] 228 | df_scaled = pd.DataFrame(np.column_stack([X_scaled, y]), columns=cols) 229 | df_scaled.to_csv(outdir / "dataset.csv", index=False) 230 | 231 | # Save metadata 232 | meta = { 233 | "cv_accuracy_mean": float(cv_scores.mean()), 234 | "cv_accuracy_std": float(cv_scores.std()), 235 | "roc_auc": float(roc_auc), 236 | "n_train": int(len(y_train)), 237 | "n_test": int(len(y_test)), 238 | "n_features": int(X.shape[1]) 239 | } 240 | with (outdir / "metadata.json").open("w") as f: 241 | json.dump(meta, f, indent=2) 242 | 243 | return { 244 | "report_path": str(report_txt), 245 | "confusion_matrix": str(outdir / "confusion_matrix.png"), 246 | "roc_curve": str(outdir / "roc_curve.png"), 247 | "feature_importance": str(outdir / 
"feature_importance.png"), 248 | "dataset_csv": str(outdir / "dataset.csv"), 249 | "metadata_json": str(outdir / "metadata.json") 250 | } 251 | 252 | def parse_args(): 253 | ap = argparse.ArgumentParser() 254 | ap.add_argument("--samples", type=int, default=600, help="Total number of windows (>= 100)") 255 | ap.add_argument("--channels", type=int, default=4, help="Number of EEG channels") 256 | ap.add_argument("--fs", type=float, default=128.0, help="Sampling frequency (Hz)") 257 | ap.add_argument("--duration", type=float, default=2.0, help="Window duration (seconds)") 258 | ap.add_argument("--seed", type=int, default=42) 259 | ap.add_argument("--outdir", type=str, default="outputs", help="Output directory") 260 | return ap.parse_args() 261 | 262 | def main(): 263 | args = parse_args() 264 | if args.samples < 100: 265 | raise SystemExit("--samples must be >= 100") 266 | 267 | outdir = Path(args.outdir) 268 | outdir.mkdir(parents=True, exist_ok=True) 269 | 270 | print("Generating synthetic EEG dataset...") 271 | df = generate_dataset( 272 | n_samples=args.samples, 273 | channels=args.channels, 274 | fs=args.fs, 275 | duration=args.duration, 276 | seed=args.seed 277 | ) 278 | 279 | print("Training classifier and creating outputs...") 280 | paths = train_and_evaluate(df, outdir, seed=args.seed) 281 | 282 | print("Done. Outputs:") 283 | for k, v in paths.items(): 284 | print(f"{k}: {v}") 285 | 286 | if __name__ == "__main__": 287 | main() 288 | --------------------------------------------------------------------------------