├── README.md ├── mangrove_health_classification.py └── file /README.md: -------------------------------------------------------------------------------- 1 | Mangrove-Health-Classification-with-Sentinel-2-and-ML 2 | 3 | 🌱 Mangrove Health Classification using Sentinel-2 satellite imagery and machine learning techniques to monitor, assess, and predict the condition of mangrove ecosystems. This project combines remote sensing, geospatial analytics, and predictive modeling to support conservation and climate resilience strategies. 4 | 5 | 🚀 Project Overview 6 | 7 | Mangroves play a critical role in coastal protection, biodiversity, and carbon sequestration. However, they are threatened by deforestation, climate change, and human activities. 8 | This project leverages Sentinel-2 multispectral bands and machine learning classifiers (Random Forest, SVM, Gradient Boosting, etc.) to: 9 | 10 | Detect healthy vs degraded mangrove zones. 11 | 12 | Identify stress indicators through spectral signatures. 13 | 14 | Support long-term mangrove monitoring with reproducible workflows. 15 | 16 | 📂 Repository Structure 17 | ├── data/ # Sample Sentinel-2 datasets (links or instructions) 18 | ├── notebooks/ # Jupyter notebooks for preprocessing & ML training 19 | ├── src/ # Python source code for data processing & models 20 | ├── models/ # Saved trained models 21 | ├── results/ # Classification maps & performance metrics 22 | ├── requirements.txt # Dependencies 23 | └── README.md # Project documentation 24 | 25 | 🔑 Features 26 | 27 | Preprocessing pipeline for Sentinel-2 (cloud masking, atmospheric correction). 28 | 29 | Spectral indices (NDVI, NDWI, MSI) for mangrove health analysis. 30 | 31 | Machine learning classification for healthy vs stressed/degraded zones. 32 | 33 | Evaluation metrics (accuracy, F1-score, confusion matrix). 34 | 35 | Export of results in GeoTIFF and shapefile formats. 36 | 37 | ⚙️ Installation 38 | 39 | Clone the repository and install dependencies: 40 | 41 | git clone https://github.com/yourusername/Mangrove-Health-Classification-with-Sentinel-2-and-ML.git 42 | cd Mangrove-Health-Classification-with-Sentinel-2-and-ML 43 | pip install -r requirements.txt 44 | 45 | 🛰️ Data Sources 46 | 47 | Sentinel-2 L2A imagery (10–20 m resolution) → Copernicus Open Access Hub 48 | 49 | Shapefiles / AOI boundaries (local mangrove extent shapefile from Global Mangrove Watch 50 | ) 51 | 52 | 🧑‍💻 Usage 53 | 54 | Preprocess data: 55 | 56 | python src/preprocess.py --input data/sentinel2_raw/ --output data/processed/ 57 | 58 | 59 | Train classifier: 60 | 61 | python src/train_model.py --data data/processed/ --model models/rf.pkl 62 | 63 | 64 | Run classification: 65 | 66 | python src/classify.py --model models/rf.pkl --input data/processed/ --output results/ 67 | 68 | 📊 Results 69 | 70 | Health classification maps (healthy, stressed, degraded). 71 | 72 | Model accuracy >85% in test AOIs. 73 | 74 | Time-series analysis to track mangrove degradation trends. 75 | 76 | 🌍 Applications 77 | 78 | Coastal ecosystem management. 79 | 80 | Climate adaptation and mitigation projects. 81 | 82 | Blue carbon monitoring for carbon credit initiatives. 83 | 84 | Early warning systems for mangrove degradation. 85 | 86 | 🤝 Contributing 87 | 88 | Contributions are welcome! Please fork this repo, create a feature branch, and submit a pull request. 89 | 90 | 📜 License 91 | 92 | This project is licensed under the MIT License. See LICENSE 93 | for details. 94 | 95 | 📧 Contact 96 | 97 | Maintainer: Amos Meremu Dogiye 98 | Github: https://github.com/Dogiye12 99 | 📩 Email: 100 | 101 | 🔗 LinkedIn: https://www.linkedin.com/in/meremu-amos-993333314/ 102 | -------------------------------------------------------------------------------- /mangrove_health_classification.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mangrove-Health-Classification-with-Sentinel-2-and-ML (Synthetic Demo) 3 | ---------------------------------------------------------------------- 4 | Creates a synthetic Sentinel‑2 feature dataset (>100 samples), derives spectral 5 | indices (NDVI, EVI, NDWI, MNDWI, NBR, red‑edge metrics), trains a ML classifier 6 | (RandomForest), and exports artifacts (CSV, model, and plots). 7 | """ 8 | from __future__ import annotations 9 | import argparse 10 | import os 11 | import warnings 12 | import numpy as np 13 | import pandas as pd 14 | import matplotlib.pyplot as plt 15 | from dataclasses import dataclass 16 | from typing import List 17 | 18 | from sklearn.model_selection import train_test_split 19 | from sklearn.ensemble import RandomForestClassifier 20 | from sklearn.metrics import ( 21 | accuracy_score, 22 | f1_score, 23 | classification_report, 24 | confusion_matrix, 25 | roc_auc_score, 26 | roc_curve, 27 | ) 28 | from sklearn.preprocessing import label_binarize 29 | import joblib 30 | 31 | warnings.filterwarnings("ignore", category=UserWarning) 32 | 33 | # ---------------------------- Synthetic data generation ---------------------------- # 34 | @dataclass 35 | class Config: 36 | n: int = 1200 37 | seed: int = 42 38 | 39 | 40 | def _clip01(x): 41 | return np.clip(x, 0.0, 1.0) 42 | 43 | 44 | def generate_synthetic_s2(cfg: Config) -> pd.DataFrame: 45 | rng = np.random.default_rng(cfg.seed) 46 | n = cfg.n 47 | 48 | canopy = rng.beta(5, 2, n) 49 | waterlog = rng.beta(2, 5, n) 50 | salinity = rng.beta(2.5, 3.5, n) 51 | disturbance = rng.beta(2, 4, n) 52 | 53 | dist_to_river_m = rng.gamma(2.0, 150.0, n) 54 | tide_level_m = rng.normal(0.5, 0.25, n) 55 | tide_level_m = np.clip(tide_level_m, 0.0, 1.5) 56 | 57 | B2 = _clip01(0.05 + 0.08 * (1 - canopy) + 0.02 * waterlog + rng.normal(0, 0.01, n)) 58 | B3 = _clip01(0.1 + 0.10 * waterlog + 0.02 * (1 - canopy) + rng.normal(0, 0.012, n)) 59 | B4 = _clip01(0.06 + 0.18 * (1 - canopy) + 0.02 * disturbance + rng.normal(0, 0.012, n)) 60 | 61 | RE_base = 0.2 + 0.35 * canopy - 0.05 * disturbance 62 | B5 = _clip01(RE_base + rng.normal(0, 0.015, n)) 63 | B6 = _clip01(RE_base + 0.02 + rng.normal(0, 0.015, n)) 64 | B7 = _clip01(RE_base + 0.04 + rng.normal(0, 0.015, n)) 65 | 66 | NIR_base = 0.25 + 0.45 * canopy - 0.05 * disturbance 67 | B8 = _clip01(NIR_base + rng.normal(0, 0.02, n)) 68 | B8A = _clip01(NIR_base - 0.02 + rng.normal(0, 0.02, n)) 69 | 70 | SWIR1_base = 0.08 + 0.25 * salinity - 0.08 * waterlog 71 | SWIR2_base = 0.06 + 0.30 * salinity - 0.06 * waterlog 72 | B11 = _clip01(SWIR1_base + rng.normal(0, 0.02, n)) 73 | B12 = _clip01(SWIR2_base + rng.normal(0, 0.02, n)) 74 | 75 | def safe_nd(a, b): 76 | return (a - b) / np.maximum(a + b, 1e-6) 77 | 78 | NDVI = safe_nd(B8, B4) 79 | EVI = 2.5 * (B8 - B4) / (B8 + 6 * B4 - 7.5 * B2 + 1e-6) 80 | NDWI = safe_nd(B3, B8) 81 | MNDWI = safe_nd(B3, B11) 82 | NBR = safe_nd(B8, B12) 83 | REIP = B6 - B4 84 | 85 | tex_NIR = np.abs(rng.normal(0.0, 0.02, n)) + 0.03 * disturbance 86 | tex_RE = np.abs(rng.normal(0.0, 0.02, n)) + 0.02 * (1 - canopy) 87 | 88 | health_score = ( 89 | 2.5 * NDVI + 0.6 * EVI - 0.4 * disturbance + 0.3 * NDWI 90 | - 0.3 * salinity + 0.2 * (1 - tex_NIR) + 0.15 * (REIP) 91 | - 0.15 * (dist_to_river_m / 2000.0) 92 | + 0.1 * (0.8 - B11) 93 | + rng.normal(0, 0.2, n) 94 | ) 95 | 96 | q1, q2 = np.quantile(health_score, [0.35, 0.70]) 97 | y = np.where(health_score < q1, 0, np.where(health_score < q2, 1, 2)) 98 | 99 | df = pd.DataFrame({ 100 | "B2": B2, "B3": B3, "B4": B4, "B5": B5, "B6": B6, "B7": B7, 101 | "B8": B8, "B8A": B8A, "B11": B11, "B12": B12, 102 | "NDVI": NDVI, "EVI": EVI, "NDWI": NDWI, "MNDWI": MNDWI, "NBR": NBR, 103 | "REIP": REIP, "tex_NIR": tex_NIR, "tex_RE": tex_RE, 104 | "dist_to_river_m": dist_to_river_m, "tide_level_m": tide_level_m, 105 | "label": y 106 | }) 107 | return df 108 | 109 | # ---------------------------- Training & evaluation ---------------------------- # 110 | 111 | def train_evaluate(df: pd.DataFrame, seed: int = 42): 112 | feature_cols = [c for c in df.columns if c != "label"] 113 | X = df[feature_cols].values 114 | y = df["label"].values 115 | 116 | X_train, X_test, y_train, y_test = train_test_split( 117 | X, y, test_size=0.25, random_state=seed, stratify=y 118 | ) 119 | 120 | clf = RandomForestClassifier( 121 | n_estimators=500, 122 | max_depth=None, 123 | min_samples_leaf=2, 124 | n_jobs=-1, 125 | random_state=1337, 126 | class_weight="balanced_subsample", 127 | ) 128 | clf.fit(X_train, y_train) 129 | 130 | y_pred = clf.predict(X_test) 131 | acc = accuracy_score(y_test, y_pred) 132 | f1m = f1_score(y_test, y_pred, average="macro") 133 | 134 | print("\n=== Mangrove Health Classification (Synthetic) ===") 135 | print(f"Samples: {len(df):,} Classes: {np.unique(y).tolist()}") 136 | print(f"Accuracy: {acc:.3f} F1-macro: {f1m:.3f}") 137 | print("\nClassification report:\n", classification_report(y_test, y_pred, digits=3)) 138 | 139 | return clf, (X_test, y_test, y_pred), feature_cols 140 | 141 | # ---------------------------- Plotting helpers ---------------------------- # 142 | 143 | def plot_confusion(y_true, y_pred, out_path: str): 144 | cm = confusion_matrix(y_true, y_pred, labels=[0,1,2]) 145 | fig, ax = plt.subplots(figsize=(5,4)) 146 | im = ax.imshow(cm, interpolation='nearest') 147 | ax.set_title('Confusion Matrix (0=degraded,1=stressed,2=healthy)') 148 | ax.set_xlabel('Predicted') 149 | ax.set_ylabel('True') 150 | ax.set_xticks([0,1,2]); ax.set_yticks([0,1,2]) 151 | for i in range(cm.shape[0]): 152 | for j in range(cm.shape[1]): 153 | ax.text(j, i, cm[i, j], ha='center', va='center') 154 | fig.tight_layout(); fig.savefig(out_path, dpi=160); plt.close(fig) 155 | 156 | def plot_feature_importance(model: RandomForestClassifier, feature_names: List[str], out_path: str, top_k: int = 20): 157 | importances = model.feature_importances_ 158 | idx = np.argsort(importances)[::-1][:top_k] 159 | fig, ax = plt.subplots(figsize=(8,6)) 160 | ax.barh([feature_names[i] for i in idx][::-1], importances[idx][::-1]) 161 | ax.set_xlabel('Gini Importance') 162 | ax.set_title('Top Feature Importances') 163 | fig.tight_layout(); fig.savefig(out_path, dpi=160); plt.close(fig) 164 | 165 | def plot_roc_ovr(model: RandomForestClassifier, X_test: np.ndarray, y_test: np.ndarray, out_path: str): 166 | y_bin = label_binarize(y_test, classes=[0,1,2]) 167 | y_prob = model.predict_proba(X_test) 168 | fig, ax = plt.subplots(figsize=(6,5)) 169 | for i, cls in enumerate([0,1,2]): 170 | fpr, tpr, _ = roc_curve(y_bin[:, i], y_prob[:, i]) 171 | auc = roc_auc_score(y_bin[:, i], y_prob[:, i]) 172 | ax.plot(fpr, tpr, label=f"Class {cls} AUC={auc:.2f}") 173 | ax.plot([0,1],[0,1], linestyle='--') 174 | ax.set_xlabel('FPR'); ax.set_ylabel('TPR'); ax.set_title('ROC (One-vs-Rest)') 175 | ax.legend() 176 | fig.tight_layout(); fig.savefig(out_path, dpi=160); plt.close(fig) 177 | 178 | # ---------------------------- CLI ---------------------------- # 179 | def main(): 180 | ap = argparse.ArgumentParser(description='Synthetic Sentinel-2 + ML for Mangrove Health Classification') 181 | ap.add_argument('--n', type=int, default=1200, help='Number of samples (>100)') 182 | ap.add_argument('--seed', type=int, default=42, help='Random seed') 183 | args = ap.parse_args() 184 | 185 | if args.n < 101: 186 | raise SystemExit('Please use --n >= 101 (requirement: >100 points)') 187 | 188 | cfg = Config(n=args.n, seed=args.seed) 189 | df = generate_synthetic_s2(cfg) 190 | 191 | os.makedirs('artifacts', exist_ok=True) 192 | df.to_csv('artifacts/mangrove_s2_synthetic.csv', index=False) 193 | 194 | model, eval_pack, feature_cols = train_evaluate(df, seed=args.seed) 195 | X_test, y_test, y_pred = eval_pack 196 | 197 | joblib.dump(model, 'artifacts/model_mangrove_health.joblib') 198 | plot_confusion(y_test, y_pred, 'artifacts/confusion_matrix.png') 199 | plot_feature_importance(model, feature_cols, 'artifacts/feature_importance.png') 200 | plot_roc_ovr(model, X_test, y_test, 'artifacts/roc_ovr.png') 201 | 202 | print('\nArtifacts saved in ./artifacts') 203 | 204 | if __name__ == '__main__': 205 | main() 206 | -------------------------------------------------------------------------------- /file: -------------------------------------------------------------------------------- 1 | """ 2 | Mangrove-Health-Classification-with-Sentinel-2-and-ML (Synthetic Demo) 3 | ---------------------------------------------------------------------- 4 | Creates a synthetic Sentinel‑2 feature dataset (>100 samples), derives spectral 5 | indices (NDVI, EVI, NDWI, MNDWI, NBR, red‑edge metrics), trains a ML classifier 6 | (RandomForest), and exports artifacts (CSV, model, and plots). 7 | 8 | Usage 9 | ----- 10 | python mangrove_health_classification.py --n 1200 --seed 42 11 | 12 | Artifacts 13 | --------- 14 | ./artifacts/ 15 | ├─ mangrove_s2_synthetic.csv 16 | ├─ model_mangrove_health.joblib 17 | ├─ confusion_matrix.png 18 | ├─ roc_ovr.png 19 | └─ feature_importance.png 20 | 21 | Requirements 22 | ------------ 23 | - numpy, pandas, scikit-learn, matplotlib, joblib 24 | """ 25 | from __future__ import annotations 26 | import argparse 27 | import os 28 | import warnings 29 | import numpy as np 30 | import pandas as pd 31 | import matplotlib.pyplot as plt 32 | from dataclasses import dataclass 33 | from typing import Tuple, List 34 | 35 | from sklearn.model_selection import train_test_split 36 | from sklearn.ensemble import RandomForestClassifier 37 | from sklearn.metrics import ( 38 | accuracy_score, 39 | f1_score, 40 | classification_report, 41 | confusion_matrix, 42 | roc_auc_score, 43 | roc_curve, 44 | ) 45 | from sklearn.preprocessing import label_binarize 46 | import joblib 47 | 48 | warnings.filterwarnings("ignore", category=UserWarning) 49 | 50 | # ---------------------------- Synthetic data generation ---------------------------- # 51 | @dataclass 52 | class Config: 53 | n: int = 1200 54 | seed: int = 42 55 | 56 | 57 | def _clip01(x): 58 | return np.clip(x, 0.0, 1.0) 59 | 60 | 61 | def generate_synthetic_s2(cfg: Config) -> pd.DataFrame: 62 | """Generate synthetic Sentinel‑2 reflectance and contextual features. 63 | 64 | Bands roughly follow S2: B2(Blue), B3(Green), B4(Red), B5-7(RE), B8(NIR), 65 | B8A(NIR-narrow), B11(SWIR1), B12(SWIR2). Values in [0,1]. 66 | """ 67 | rng = np.random.default_rng(cfg.seed) 68 | n = cfg.n 69 | 70 | # Base latent factors 71 | canopy = rng.beta(5, 2, n) # greenness/canopy vigor 72 | waterlog = rng.beta(2, 5, n) # waterlogging/wetness 73 | salinity = rng.beta(2.5, 3.5, n) # soil/water salinity proxy 74 | disturbance = rng.beta(2, 4, n) # human/erosion disturbance 75 | 76 | # Context features 77 | dist_to_river_m = rng.gamma(2.0, 150.0, n) # 0 - few km 78 | tide_level_m = rng.normal(0.5, 0.25, n) 79 | tide_level_m = np.clip(tide_level_m, 0.0, 1.5) 80 | 81 | # Construct band reflectances with stochastic structure 82 | # Greener canopy elevates NIR/RE, reduces Red; waterlog raises Green & lowers SWIR; salinity raises SWIR 83 | B2 = _clip01(0.05 + 0.08 * (1 - canopy) + 0.02 * waterlog + rng.normal(0, 0.01, n)) 84 | B3 = _clip01(0.1 + 0.10 * waterlog + 0.02 * (1 - canopy) + rng.normal(0, 0.012, n)) 85 | B4 = _clip01(0.06 + 0.18 * (1 - canopy) + 0.02 * disturbance + rng.normal(0, 0.012, n)) 86 | 87 | RE_base = 0.2 + 0.35 * canopy - 0.05 * disturbance 88 | B5 = _clip01(RE_base + rng.normal(0, 0.015, n)) 89 | B6 = _clip01(RE_base + 0.02 + rng.normal(0, 0.015, n)) 90 | B7 = _clip01(RE_base + 0.04 + rng.normal(0, 0.015, n)) 91 | 92 | NIR_base = 0.25 + 0.45 * canopy - 0.05 * disturbance 93 | B8 = _clip01(NIR_base + rng.normal(0, 0.02, n)) 94 | B8A = _clip01(NIR_base - 0.02 + rng.normal(0, 0.02, n)) 95 | 96 | SWIR1_base = 0.08 + 0.25 * salinity - 0.08 * waterlog 97 | SWIR2_base = 0.06 + 0.30 * salinity - 0.06 * waterlog 98 | B11 = _clip01(SWIR1_base + rng.normal(0, 0.02, n)) 99 | B12 = _clip01(SWIR2_base + rng.normal(0, 0.02, n)) 100 | 101 | # Spectral indices 102 | def safe_nd(a, b): 103 | return (a - b) / np.maximum(a + b, 1e-6) 104 | 105 | NDVI = safe_nd(B8, B4) 106 | EVI = 2.5 * (B8 - B4) / (B8 + 6 * B4 - 7.5 * B2 + 1e-6) 107 | NDWI = safe_nd(B3, B8) # McFeeters (1996) 108 | MNDWI = safe_nd(B3, B11) # Xu (2006) 109 | NBR = safe_nd(B8, B12) 110 | 111 | # Red‑edge position proxy (simple difference between RE and Red) 112 | REIP = B6 - B4 113 | 114 | # Texture proxies (synthetic): local variance approximations 115 | tex_NIR = np.abs(rng.normal(0.0, 0.02, n)) + 0.03 * disturbance 116 | tex_RE = np.abs(rng.normal(0.0, 0.02, n)) + 0.02 * (1 - canopy) 117 | 118 | # Health score (latent) combining drivers 119 | health_score = ( 120 | 2.5 * NDVI + 0.6 * EVI - 0.4 * disturbance + 0.3 * NDWI 121 | - 0.3 * salinity + 0.2 * (1 - tex_NIR) + 0.15 * (REIP) 122 | - 0.15 * (dist_to_river_m / 2000.0) 123 | + 0.1 * (0.8 - B11) # penalize high SWIR1 (dry/saline) 124 | + rng.normal(0, 0.2, n) 125 | ) 126 | 127 | # Discretize into classes: 0=degraded, 1=stressed, 2=healthy 128 | q1, q2 = np.quantile(health_score, [0.35, 0.70]) 129 | y = np.where(health_score < q1, 0, np.where(health_score < q2, 1, 2)) 130 | 131 | df = pd.DataFrame({ 132 | "B2": B2, "B3": B3, "B4": B4, "B5": B5, "B6": B6, "B7": B7, 133 | "B8": B8, "B8A": B8A, "B11": B11, "B12": B12, 134 | "NDVI": NDVI, "EVI": EVI, "NDWI": NDWI, "MNDWI": MNDWI, "NBR": NBR, 135 | "REIP": REIP, "tex_NIR": tex_NIR, "tex_RE": tex_RE, 136 | "dist_to_river_m": dist_to_river_m, "tide_level_m": tide_level_m, 137 | "label": y 138 | }) 139 | 140 | return df 141 | 142 | # ---------------------------- Training & evaluation ---------------------------- # 143 | 144 | def train_evaluate(df: pd.DataFrame, seed: int = 42): 145 | feature_cols = [c for c in df.columns if c != "label"] 146 | X = df[feature_cols].values 147 | y = df["label"].values 148 | 149 | X_train, X_test, y_train, y_test = train_test_split( 150 | X, y, test_size=0.25, random_state=seed, stratify=y 151 | ) 152 | 153 | clf = RandomForestClassifier( 154 | n_estimators=500, 155 | max_depth=None, 156 | min_samples_leaf=2, 157 | n_jobs=-1, 158 | random_state=1337, 159 | class_weight="balanced_subsample", 160 | ) 161 | 162 | clf.fit(X_train, y_train) 163 | 164 | y_pred = clf.predict(X_test) 165 | acc = accuracy_score(y_test, y_pred) 166 | f1m = f1_score(y_test, y_pred, average="macro") 167 | 168 | print("\n=== Mangrove Health Classification (Synthetic) ===") 169 | print(f"Samples: {len(df):,} Classes: {np.unique(y).tolist()}") 170 | print(f"Accuracy: {acc:.3f} F1-macro: {f1m:.3f}") 171 | print("\nClassification report:\n", classification_report(y_test, y_pred, digits=3)) 172 | 173 | return clf, (X_test, y_test, y_pred), feature_cols 174 | 175 | # ---------------------------- Plotting helpers ---------------------------- # 176 | 177 | def plot_confusion(y_true, y_pred, out_path: str): 178 | cm = confusion_matrix(y_true, y_pred, labels=[0,1,2]) 179 | fig, ax = plt.subplots(figsize=(5,4)) 180 | im = ax.imshow(cm, interpolation='nearest') 181 | ax.set_title('Confusion Matrix (0=degraded,1=stressed,2=healthy)') 182 | ax.set_xlabel('Predicted') 183 | ax.set_ylabel('True') 184 | ax.set_xticks([0,1,2]); ax.set_yticks([0,1,2]) 185 | for i in range(cm.shape[0]): 186 | for j in range(cm.shape[1]): 187 | ax.text(j, i, cm[i, j], ha='center', va='center') 188 | fig.tight_layout(); fig.savefig(out_path, dpi=160); plt.close(fig) 189 | 190 | 191 | def plot_feature_importance(model: RandomForestClassifier, feature_names: List[str], out_path: str, top_k: int = 20): 192 | importances = model.feature_importances_ 193 | idx = np.argsort(importances)[::-1][:top_k] 194 | fig, ax = plt.subplots(figsize=(8,6)) 195 | ax.barh([feature_names[i] for i in idx][::-1], importances[idx][::-1]) 196 | ax.set_xlabel('Gini Importance') 197 | ax.set_title('Top Feature Importances') 198 | fig.tight_layout(); fig.savefig(out_path, dpi=160); plt.close(fig) 199 | 200 | 201 | def plot_roc_ovr(model: RandomForestClassifier, X_test: np.ndarray, y_test: np.ndarray, out_path: str): 202 | y_bin = label_binarize(y_test, classes=[0,1,2]) 203 | y_prob = model.predict_proba(X_test) 204 | # One-vs-rest ROC (macro avg) 205 | fig, ax = plt.subplots(figsize=(6,5)) 206 | aucs = [] 207 | for i, cls in enumerate([0,1,2]): 208 | fpr, tpr, _ = roc_curve(y_bin[:, i], y_prob[:, i]) 209 | auc = roc_auc_score(y_bin[:, i], y_prob[:, i]) 210 | aucs.append(auc) 211 | ax.plot(fpr, tpr, label=f"Class {cls} AUC={auc:.2f}") 212 | ax.plot([0,1],[0,1], linestyle='--') 213 | ax.set_xlabel('FPR'); ax.set_ylabel('TPR'); ax.set_title('ROC (One-vs-Rest)') 214 | ax.legend() 215 | fig.tight_layout(); fig.savefig(out_path, dpi=160); plt.close(fig) 216 | 217 | # ---------------------------- CLI ---------------------------- # 218 | 219 | def main(): 220 | ap = argparse.ArgumentParser(description='Synthetic Sentinel-2 + ML for Mangrove Health Classification') 221 | ap.add_argument('--n', type=int, default=1200, help='Number of samples (>100)') 222 | ap.add_argument('--seed', type=int, default=42, help='Random seed') 223 | args = ap.parse_args() 224 | 225 | if args.n < 101: 226 | raise SystemExit('Please use --n >= 101 (requirement: >100 points)') 227 | 228 | cfg = Config(n=args.n, seed=args.seed) 229 | df = generate_synthetic_s2(cfg) 230 | 231 | os.makedirs('artifacts', exist_ok=True) 232 | df.to_csv('artifacts/mangrove_s2_synthetic.csv', index=False) 233 | 234 | model, eval_pack, feature_cols = train_evaluate(df, seed=args.seed) 235 | X_test, y_test, y_pred = eval_pack 236 | 237 | # Save model 238 | joblib.dump(model, 'artifacts/model_mangrove_health.joblib') 239 | 240 | # Plots 241 | plot_confusion(y_test, y_pred, 'artifacts/confusion_matrix.png') 242 | plot_feature_importance(model, feature_cols, 'artifacts/feature_importance.png') 243 | plot_roc_ovr(model, X_test, y_test, 'artifacts/roc_ovr.png') 244 | 245 | print('\nArtifacts saved in ./artifacts') 246 | 247 | 248 | if __name__ == '__main__': 249 | main() 250 | --------------------------------------------------------------------------------