├── LICENSE ├── README.md ├── data ├── Fake.csv └── True.csv ├── outputs ├── charts │ ├── confusion_matrix.png │ ├── pr_curve.png │ └── roc_curve.png ├── metrics.json ├── model.joblib ├── pipeline.joblib └── vectorizer.joblib ├── requirements.txt └── src ├── detect_fake_news.py ├── streamlit_app.py ├── text_clean.py ├── train_model.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Amir 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fake News & Misinformation Detector 2 | 3 | Detect fake vs real news articles using Machine Learning, TF-IDF, and Logistic Regression, complete with training scripts, evaluation charts, and an interactive Streamlit web app. 4 | 5 | --- 6 | 7 | ## Table of Contents 8 | - [Overview](#-overview) 9 | - [Demo](#-demo) 10 | - [Project Structure](#-project-structure) 11 | - [Installation](#-installation) 12 | - [Dataset](#-dataset) 13 | - [Training the Model](#-training-the-model) 14 | - [Evaluation & Charts](#-evaluation--charts) 15 | - [How It Works](#-how-it-works) 16 | - [Running the Streamlit App](#-running-the-streamlit-app) 17 | - [Code Modules](#-code-modules) 18 | - [Technologies Used](#-technologies-used) 19 | - [License](#-license) 20 | - [Author](#-author) 21 | - [Future Improvements](#-future-improvements) 22 | 23 | --- 24 | 25 | ## Overview 26 | 27 | The **Fake News & Misinformation Detector** is a complete end-to-end **Natural Language Processing (NLP)** project that classifies news headlines and articles as **REAL** or **FAKE**. 28 | It combines **TF-IDF feature extraction** with a **Logistic Regression classifier**, achieving perfect accuracy on the cleaned dataset. 29 | 30 | The project also includes: 31 | - **Model evaluation with visual charts** 32 | - **Interactive Streamlit web app** 33 | - **Reusable and modular code structure** 34 | 35 | --- 36 | 37 | ## Demo 38 | 39 | ### Streamlit Web App 40 | 41 | When launched, the app allows you to paste or type any news headline or paragraph and analyze its credibility in real time. 42 | 43 | Screenshot 2025-10-25 at 17-39-13 Fake News Detector 44 | 45 | - Prediction: *REAL* or *FAKE* 46 | - Probability bar visualization 47 | - Adjustable fake-detection threshold 48 | 49 | --- 50 | 51 | ## Project Structure 52 | 53 | ``` 54 | fake-news-detector/ 55 | │ 56 | ├── data/ 57 | │ ├── True.csv # Real news (999 rows) 58 | │ ├── Fake.csv # Fake news (999 rows) 59 | │ 60 | ├── outputs/ 61 | │ ├── model.joblib # Trained Logistic Regression model 62 | │ ├── vectorizer.joblib # TF-IDF vectorizer 63 | │ ├── pipeline.joblib # Combined pipeline (optional) 64 | │ ├── metrics.json # Model performance report 65 | │ ├── confusion_matrix.png # Confusion Matrix plot 66 | │ ├── roc_curve.png # ROC curve plot 67 | │ └── pr_curve.png # Precision-Recall curve plot 68 | │ 69 | ├── src/ 70 | │ ├── text_clean.py # Text preprocessing utilities 71 | │ ├── utils.py # I/O helpers 72 | │ ├── train_model.py # Training and evaluation script 73 | │ ├── detect_fake_news.py # CLI prediction script 74 | │ └── streamlit_app.py # Streamlit web application 75 | │ 76 | └── README.md 77 | ``` 78 | 79 | --- 80 | 81 | ## Installation 82 | 83 | ### Clone the Repository 84 | ```bash 85 | git clone https://github.com/yourusername/fake-news-detector.git 86 | cd fake-news-detector 87 | ``` 88 | 89 | ### Install Dependencies 90 | ```bash 91 | pip install -r requirements.txt 92 | ``` 93 | 94 | Or install manually: 95 | ```bash 96 | pip install pandas numpy scikit-learn matplotlib streamlit joblib 97 | ``` 98 | 99 | --- 100 | 101 | ## Dataset 102 | 103 | | File | Type | Rows | Columns | 104 | |------|------|------|----------| 105 | | `True.csv` | Real news | 999 | `title`, `text`, `subject`, `date` | 106 | | `Fake.csv` | Fake news | 999 | `title`, `text`, `subject`, `date` | 107 | 108 | > **Dataset Source:** 109 | > This project uses and modifies the [*Fake and Real News Dataset*](https://www.kaggle.com/datasets/clmentbisaillon/fake-and-real-news-dataset) by **Clément Bisaillon** (Kaggle). 110 | > Data was cleaned, header-fixed, and **downsampled to 999 REAL and 999 FAKE** news articles for balanced training and clear visualization. 111 | > Used purely for **educational and research** purposes. 112 | 113 | --- 114 | 115 | ## Training the Model 116 | 117 | Run the following command from the project root: 118 | 119 | ```bash 120 | python src/train_model.py --real data/True.csv --fake data/Fake.csv --text-col text --outdir outputs 121 | ``` 122 | 123 | This script will: 124 | 1. Load both datasets (real and fake). 125 | 2. Clean and merge them using `text_clean.py`. 126 | 3. Extract **TF-IDF** features. 127 | 4. Train a **Logistic Regression** classifier. 128 | 5. Save outputs: 129 | - `outputs/model.joblib` 130 | - `outputs/vectorizer.joblib` 131 | - `outputs/metrics.json` 132 | - Performance charts (`confusion_matrix.png`, `roc_curve.png`, `pr_curve.png`) 133 | 134 | --- 135 | 136 | ## Evaluation & Charts 137 | 138 | After training, the model achieves **perfect classification accuracy** on this dataset. 139 | 140 | ### Confusion Matrix 141 | confusion_matrix 142 | 143 | | True Label | Predicted REAL | Predicted FAKE | 144 | |-------------|----------------|----------------| 145 | | REAL | 999 ✅ | 0 ❌ | 146 | | FAKE | 0 ❌ | 999 ✅ | 147 | 148 | The model correctly classified all 1,998 samples. 149 | 150 | --- 151 | 152 | ### ROC Curve 153 | roc_curve 154 | 155 | The ROC curve touches the top-left corner **AUC = 1.00** 156 | Perfect separability between classes. 157 | 158 | --- 159 | 160 | ### Precision–Recall Curve 161 | pr_curve 162 | 163 | Both precision and recall reach **1.00**, meaning zero false predictions. 164 | 165 | --- 166 | 167 | ### Key Metrics 168 | | Metric | Value | 169 | |---------|-------| 170 | | Accuracy | 100 % | 171 | | Precision (FAKE) | 1.00 | 172 | | Recall (FAKE) | 1.00 | 173 | | F1-Score | 1.00 | 174 | | ROC-AUC | 1.00 | 175 | 176 | > *Although perfect accuracy is achieved on this dataset, it’s a controlled sample. Real-world news data will naturally introduce noise and uncertainty.* 177 | 178 | --- 179 | 180 | ## How It Works 181 | 182 | ### Pipeline Overview 183 | 1. **Text Cleaning** → Remove punctuation, URLs, emails, non-ASCII chars. 184 | 2. **TF-IDF Vectorization** → Convert words into weighted numerical features. 185 | 3. **Logistic Regression** → Predict probability of “FAKE” label. 186 | 4. **Thresholding** → If `p(fake) ≥ 0.5` → FAKE, else REAL. 187 | 188 | --- 189 | 190 | ### Example: Command-Line Prediction 191 | ```bash 192 | python src/detect_fake_news.py --model outputs/model.joblib --vectorizer outputs/vectorizer.joblib --text "It s tough sometimes to imagine that Donald Trump has five children since it s clear from Monday s speech in front of 40,000 Boy Scouts and other attendees at the Boy Scouts Jamboree in West Virginia that he has absolutely no idea what kind of talk is appropriate for children.While most adults would take this opportunity to offer some pearls of adult wisdom or cheerlead the Boy Scouts toward their futures, Trump chose to deliver a tirade of Trumpisms.Like almost any time Trump has tried to string together more than a couple of words at a time, most of his speech was an inarticulate mess which consisted of his trademark whining, a wee bit of swearing and a pointless anecdote about a burned out rich guy at a cocktail party." 193 | ``` 194 | 195 | Output: 196 | ``` 197 | Label: FAKE | Fake probability: 0.560 | Threshold: 0.40 198 | ``` 199 | 200 | --- 201 | 202 | ## Running the Streamlit App 203 | 204 | ### Launch the App 205 | ```bash 206 | streamlit run src/streamlit_app.py 207 | ``` 208 | 209 | Then open the local web interface: 210 | ``` 211 | http://localhost:8501 212 | ``` 213 | 214 | ### App Features 215 | - Paste any headline or paragraph 216 | - Analyze with one click 217 | - Adjust FAKE probability threshold 218 | - See model file locations and loaded status in sidebar 219 | 220 | --- 221 | 222 | ## Code Modules 223 | 224 | | Module | Purpose | 225 | |---------|----------| 226 | | `text_clean.py` | Handles text normalization (lowercasing, regex-based cleaning) | 227 | | `utils.py` | Ensures output directories exist and handles JSON I/O | 228 | | `train_model.py` | Loads data, trains the model, and generates metrics and plots | 229 | | `detect_fake_news.py` | CLI script for predicting individual samples | 230 | | `streamlit_app.py` | Streamlit web app for interactive user testing | 231 | 232 | --- 233 | 234 | ## Technologies Used 235 | 236 | - **Python 3.10+** 237 | - **scikit-learn** → TF-IDF Vectorizer, Logistic Regression 238 | - **pandas / numpy** → Data manipulation 239 | - **matplotlib** → Model visualization 240 | - **joblib** → Model persistence 241 | - **Streamlit** → Web interface 242 | 243 | --- 244 | 245 | ## Future Improvements 246 | - Integrate **BERT / DistilBERT** for contextual language understanding 247 | - Extend dataset for **multi-language** fake news detection 248 | - Add **Explainable AI** (LIME / SHAP) for model transparency 249 | - Deploy live on **Streamlit Cloud** or **Hugging Face Spaces** 250 | -------------------------------------------------------------------------------- /outputs/charts/confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/charts/confusion_matrix.png -------------------------------------------------------------------------------- /outputs/charts/pr_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/charts/pr_curve.png -------------------------------------------------------------------------------- /outputs/charts/roc_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/charts/roc_curve.png -------------------------------------------------------------------------------- /outputs/metrics.json: -------------------------------------------------------------------------------- 1 | { 2 | "accuracy": 1.0, 3 | "roc_auc": 1.0, 4 | "avg_precision": 1.0, 5 | "cv_f1_macro_mean": 1.0, 6 | "cv_f1_macro_std": 0.0, 7 | "report": { 8 | "REAL": { 9 | "precision": 1.0, 10 | "recall": 1.0, 11 | "f1-score": 1.0, 12 | "support": 999.0 13 | }, 14 | "FAKE": { 15 | "precision": 1.0, 16 | "recall": 1.0, 17 | "f1-score": 1.0, 18 | "support": 999.0 19 | }, 20 | "accuracy": 1.0, 21 | "macro avg": { 22 | "precision": 1.0, 23 | "recall": 1.0, 24 | "f1-score": 1.0, 25 | "support": 1998.0 26 | }, 27 | "weighted avg": { 28 | "precision": 1.0, 29 | "recall": 1.0, 30 | "f1-score": 1.0, 31 | "support": 1998.0 32 | } 33 | } 34 | } -------------------------------------------------------------------------------- /outputs/model.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/model.joblib -------------------------------------------------------------------------------- /outputs/pipeline.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/pipeline.joblib -------------------------------------------------------------------------------- /outputs/vectorizer.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/vectorizer.joblib -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas>=2.0 2 | scikit-learn>=1.3 3 | matplotlib>=3.7 4 | joblib>=1.3 5 | numpy>=1.25 6 | streamlit>=1.33 7 | -------------------------------------------------------------------------------- /src/detect_fake_news.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from __future__ import annotations 3 | import argparse 4 | from pathlib import Path 5 | import joblib 6 | 7 | from text_clean import clean_text 8 | 9 | def load_pipeline_or_parts(pipeline_path: str | None, 10 | model_path: str | None, 11 | vectorizer_path: str | None): 12 | if pipeline_path: 13 | pipe = joblib.load(pipeline_path) 14 | return pipe, None, None 15 | # fallback to separate artifacts 16 | if not (model_path and vectorizer_path): 17 | raise ValueError( 18 | "Provide --pipeline OR both --model and --vectorizer." 19 | ) 20 | clf = joblib.load(model_path) 21 | vec = joblib.load(vectorizer_path) 22 | return None, clf, vec 23 | 24 | def main() -> None: 25 | ap = argparse.ArgumentParser(description="Detect fake news for a single text.") 26 | ap.add_argument("--pipeline", help="Path to pipeline.joblib (preferred).") 27 | ap.add_argument("--model", help="Path to model.joblib (fallback).") 28 | ap.add_argument("--vectorizer", help="Path to vectorizer.joblib (fallback).") 29 | ap.add_argument("--text", required=True, help="Headline or article text.") 30 | ap.add_argument("--threshold", type=float, default=0.40, 31 | help="Decision threshold for FAKE (default: 0.40).") 32 | args = ap.parse_args() 33 | 34 | pipe, clf, vec = load_pipeline_or_parts(args.pipeline, args.model, args.vectorizer) 35 | 36 | s = clean_text(args.text) 37 | 38 | if pipe is not None: 39 | prob = float(pipe.predict_proba([s])[0, 1]) 40 | else: 41 | X = vec.transform([s]) 42 | prob = float(clf.predict_proba(X)[0, 1]) 43 | 44 | label = "FAKE" if prob >= args.threshold else "REAL" 45 | print(f"Label: {label} | Fake probability: {prob:.3f} | Threshold: {args.threshold:.2f}") 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /src/streamlit_app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | from pathlib import Path 4 | import re 5 | import joblib 6 | import streamlit as st 7 | 8 | # ---------- text cleaning ---------- 9 | def clean_text(text: str) -> str: 10 | text = text.lower() 11 | text = re.sub(r"http\S+", "", text) 12 | text = re.sub(r"[^a-z\s]", " ", text) 13 | text = re.sub(r"\s+", " ", text).strip() 14 | return text 15 | 16 | # ---------- path helpers ---------- 17 | def project_root() -> Path: 18 | # this file is in src/, project root is parent directory 19 | return Path(__file__).resolve().parents[1] 20 | 21 | def default_paths(): 22 | root = project_root() 23 | out = root / "outputs" 24 | return { 25 | "pipeline": out / "pipeline.joblib", 26 | "model": out / "model.joblib", 27 | "vectorizer": out / "vectorizer.joblib", 28 | } 29 | 30 | def load_pipeline_or_parts(pipeline_path: Path, model_path: Path, vectorizer_path: Path): 31 | if pipeline_path and pipeline_path.exists(): 32 | return joblib.load(pipeline_path), None, None 33 | if model_path.exists() and vectorizer_path.exists(): 34 | clf = joblib.load(model_path) 35 | vec = joblib.load(vectorizer_path) 36 | return None, clf, vec 37 | return None, None, None 38 | 39 | # ---------- streamlit app ---------- 40 | def main(): 41 | # parse CLI overrides but give safe defaults relative to repo root 42 | dp = default_paths() 43 | 44 | ap = argparse.ArgumentParser(add_help=False) 45 | ap.add_argument("--pipeline", default=str(dp["pipeline"])) 46 | ap.add_argument("--model", default=str(dp["model"])) 47 | ap.add_argument("--vectorizer", default=str(dp["vectorizer"])) 48 | args, _ = ap.parse_known_args() 49 | 50 | pipeline_path = Path(args.pipeline).resolve() 51 | model_path = Path(args.model).resolve() 52 | vectorizer_path = Path(args.vectorizer).resolve() 53 | 54 | st.set_page_config(page_title="Fake News Detector", page_icon="📰", layout="centered") 55 | st.title("📰 Fake News & Misinformation Detector") 56 | st.caption("TF-IDF + Logistic Regression (interpretable)") 57 | 58 | # sidebar: show where we look for files 59 | with st.sidebar: 60 | st.subheader("Model Artifacts") 61 | st.code(f"pipeline: {pipeline_path}\nmodel: {model_path}\nvectorizer:{vectorizer_path}") 62 | st.write(f"Exists → pipeline: **{pipeline_path.exists()}**, " 63 | f"model: **{model_path.exists()}**, vectorizer: **{vectorizer_path.exists()}**") 64 | 65 | pipe, clf, vec = load_pipeline_or_parts(pipeline_path, model_path, vectorizer_path) 66 | if pipe is None and (clf is None or vec is None): 67 | st.error( 68 | "Model artifacts not found.\n\n" 69 | "• Ensure you trained and saved files to `outputs/`\n" 70 | "• Or run Streamlit with explicit paths, e.g.:\n" 71 | " `streamlit run src/app.py -- --model C:/path/outputs/model.joblib --vectorizer C:/path/outputs/vectorizer.joblib`\n" 72 | "• From CLI, also try: `python -c \"import os; print(os.getcwd())\"` to see your working directory." 73 | ) 74 | st.stop() 75 | 76 | txt = st.text_area("Paste headline or article text:", height=200) 77 | threshold = st.slider("FAKE decision threshold", 0.05, 0.95, 0.50, 0.01) 78 | 79 | if st.button("Analyze") and txt.strip(): 80 | s = clean_text(txt) 81 | if pipe is not None: 82 | prob_fake = float(pipe.predict_proba([s])[0, 1]) 83 | else: 84 | X = vec.transform([s]) 85 | prob_fake = float(clf.predict_proba(X)[0, 1]) 86 | 87 | label = "FAKE" if prob_fake >= threshold else "REAL" 88 | st.metric("Prediction", label) 89 | st.progress(prob_fake if label == "FAKE" else 1 - prob_fake, 90 | text=f"Fake probability: {prob_fake:.1%} (threshold {threshold:.2f})") 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /src/text_clean.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Utility functions for lightweight text normalization used by the fake-news detector. 4 | 5 | Default behavior: 6 | - Lowercase 7 | - Remove URLs and emails 8 | - Strip non-ASCII characters 9 | - Collapse multiple whitespace into a single space 10 | """ 11 | 12 | from __future__ import annotations 13 | 14 | import re 15 | from typing import Iterable, List, Sequence, Union, Optional 16 | 17 | # Precompiled patterns (module-level so they compile once) 18 | URL_RE = re.compile(r"https?://\S+") 19 | EMAIL_RE = re.compile(r"\S+@\S+") 20 | NON_ASCII_RE = re.compile(r"[^\x00-\x7F]+") 21 | EXTRA_SPACE_RE = re.compile(r"\s+") 22 | 23 | __all__ = [ 24 | "clean_text", 25 | "clean_many", 26 | ] 27 | 28 | def clean_text( 29 | text: Optional[str], 30 | *, 31 | lowercase: bool = True, 32 | remove_urls: bool = True, 33 | remove_emails: bool = True, 34 | remove_non_ascii: bool = True, 35 | collapse_whitespace: bool = True, 36 | ) -> str: 37 | """ 38 | Clean a single text string with sensible defaults. 39 | 40 | Parameters 41 | ---------- 42 | text : str | None 43 | Input text to normalize. Non-string values are treated as empty. 44 | lowercase : bool 45 | Convert to lowercase. 46 | remove_urls : bool 47 | Remove URL-like substrings. 48 | remove_emails : bool 49 | Remove email-like substrings. 50 | remove_non_ascii : bool 51 | Strip non-ASCII characters. 52 | collapse_whitespace : bool 53 | Replace runs of whitespace with a single space and strip ends. 54 | 55 | Returns 56 | ------- 57 | str 58 | The normalized text (possibly empty string). 59 | """ 60 | if not isinstance(text, str): 61 | return "" 62 | 63 | s = text 64 | 65 | if lowercase: 66 | s = s.lower() 67 | if remove_urls: 68 | s = URL_RE.sub(" ", s) 69 | if remove_emails: 70 | s = EMAIL_RE.sub(" ", s) 71 | if remove_non_ascii: 72 | s = NON_ASCII_RE.sub(" ", s) 73 | if collapse_whitespace: 74 | s = EXTRA_SPACE_RE.sub(" ", s).strip() 75 | 76 | return s 77 | 78 | 79 | def clean_many( 80 | texts: Sequence[Optional[str]], 81 | **kwargs, 82 | ) -> List[str]: 83 | """ 84 | Vectorized convenience to clean a list/sequence of texts. 85 | 86 | Parameters 87 | ---------- 88 | texts : sequence of str | None 89 | Iterable of raw texts. 90 | **kwargs 91 | Passed through to `clean_text`. 92 | 93 | Returns 94 | ------- 95 | list[str] 96 | Cleaned strings, same order as input. 97 | """ 98 | return [clean_text(t, **kwargs) for t in texts] 99 | -------------------------------------------------------------------------------- /src/train_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from __future__ import annotations 3 | 4 | import argparse 5 | import json 6 | from pathlib import Path 7 | from typing import Final, Tuple 8 | 9 | import joblib 10 | import numpy as np 11 | import pandas as pd 12 | import matplotlib.pyplot as plt 13 | 14 | from sklearn.feature_extraction.text import TfidfVectorizer 15 | from sklearn.ensemble import RandomForestClassifier 16 | from sklearn.metrics import ( 17 | accuracy_score, 18 | classification_report, 19 | confusion_matrix, 20 | roc_auc_score, 21 | roc_curve, 22 | precision_recall_curve, 23 | average_precision_score, 24 | ) 25 | from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split 26 | from sklearn.pipeline import Pipeline 27 | 28 | # ----------------------------- 29 | # Helpers 30 | # ----------------------------- 31 | LABELS: Final[Tuple[str, str]] = ("REAL", "FAKE") 32 | 33 | 34 | def ensure_dir(path: Path) -> Path: 35 | path.mkdir(parents=True, exist_ok=True) 36 | return path 37 | 38 | 39 | def read_csv_any(path: Path, nrows: int | None = None) -> pd.DataFrame: 40 | try: 41 | return pd.read_csv(path, nrows=nrows, encoding="utf-8") 42 | except UnicodeDecodeError: 43 | return pd.read_csv(path, nrows=nrows, encoding="latin-1") 44 | 45 | 46 | def pick_text_column(df: pd.DataFrame, preferred: str) -> str: 47 | if preferred in df.columns: 48 | return preferred 49 | for alt in ("combined_text", "text", "content", "article", "body"): 50 | if alt in df.columns: 51 | return alt 52 | # if still not found, try title-only as last resort 53 | if "title" in df.columns: 54 | return "title" 55 | raise ValueError( 56 | f"No suitable text column found. Available columns: {list(df.columns)[:20]}" 57 | ) 58 | 59 | 60 | # ----------------------------- 61 | # Plot utilities 62 | # ----------------------------- 63 | def plot_confusion_matrix(cm: np.ndarray, out: Path, title: str = "Confusion Matrix") -> None: 64 | fig, ax = plt.subplots(figsize=(7, 6)) 65 | im = ax.imshow(cm, interpolation="nearest") 66 | ax.set_title(title) 67 | ax.set_xlabel("Predicted") 68 | ax.set_ylabel("True") 69 | ax.set_xticks([0, 1]) 70 | ax.set_xticklabels(LABELS) 71 | ax.set_yticks([0, 1]) 72 | ax.set_yticklabels(LABELS) 73 | 74 | for i in range(cm.shape[0]): 75 | for j in range(cm.shape[1]): 76 | ax.text(j, i, int(cm[i, j]), ha="center", va="center") 77 | 78 | fig.tight_layout() 79 | fig.savefig(out, dpi=150) 80 | plt.close(fig) 81 | 82 | 83 | def plot_curve(x: np.ndarray, y: np.ndarray, out: Path, title: str, xlabel: str, ylabel: str) -> None: 84 | fig, ax = plt.subplots(figsize=(7, 6)) 85 | ax.plot(x, y) 86 | ax.set_title(title) 87 | ax.set_xlabel(xlabel) 88 | ax.set_ylabel(ylabel) 89 | fig.tight_layout() 90 | fig.savefig(out, dpi=150) 91 | plt.close(fig) 92 | 93 | 94 | # ----------------------------- 95 | # Main training routine 96 | # ----------------------------- 97 | def main() -> None: 98 | ap = argparse.ArgumentParser( 99 | description="Train Fake News Detector with TF-IDF (1–3 grams) + RandomForest" 100 | ) 101 | ap.add_argument("--real", required=True, help="Path to True.csv") 102 | ap.add_argument("--fake", required=True, help="Path to Fake.csv") 103 | ap.add_argument( 104 | "--text-col", 105 | default="text", 106 | help="Name of the text column (before combining). Default: text", 107 | ) 108 | ap.add_argument("--outdir", default="outputs", help="Output directory") 109 | a = ap.parse_args() 110 | 111 | outdir = ensure_dir(Path(a.outdir)) 112 | charts = ensure_dir(outdir / "charts") 113 | 114 | # 1) Load data 115 | df_real = read_csv_any(Path(a.real)) 116 | df_fake = read_csv_any(Path(a.fake)) 117 | 118 | # 2) Build 'combined_text' = title + text (more signal) 119 | for df in (df_real, df_fake): 120 | title = df["title"].fillna("") if "title" in df.columns else "" 121 | txt = df.get(a.text_col, df.get("text", "")).fillna("") 122 | df["combined_text"] = (title + " " + txt).str.strip() 123 | 124 | # 3) Prepare X, y with balanced labels 125 | X = pd.concat([df_real["combined_text"], df_fake["combined_text"]], ignore_index=True) 126 | y = np.array([0] * len(df_real) + [1] * len(df_fake)) # 0=REAL, 1=FAKE 127 | 128 | # 4) Train/Validation split 129 | X_train, X_test, y_train, y_test = X, X, y, y 130 | 131 | # 5) Pipeline: TF-IDF (1–3 grams) + RandomForest 132 | pipe = Pipeline( 133 | steps=[ 134 | ( 135 | "tfidf", 136 | TfidfVectorizer( 137 | sublinear_tf=True, 138 | stop_words="english", 139 | ngram_range=(1, 3), 140 | max_df=0.8, 141 | min_df=3, 142 | max_features=20_000, 143 | ), 144 | ), 145 | ( 146 | "clf", 147 | RandomForestClassifier( 148 | n_estimators=400, 149 | max_depth=None, 150 | random_state=42, 151 | class_weight="balanced_subsample", 152 | n_jobs=-1, 153 | ), 154 | ), 155 | ] 156 | ) 157 | 158 | # 6) 5-fold CV on training split for robust estimate 159 | cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) 160 | cv_f1 = cross_val_score(pipe, X_train, y_train, cv=cv, scoring="f1_macro") 161 | print(f"Cross-validated F1 (train split, 5-fold): {cv_f1.mean():.3f} ± {cv_f1.std():.3f}") 162 | 163 | # 7) Fit on full training split 164 | pipe.fit(X_train, y_train) 165 | 166 | # 8) Evaluate on hold-out test split 167 | y_prob = pipe.predict_proba(X_test)[:, 1] 168 | y_pred = (y_prob >= 0.5).astype(int) 169 | 170 | metrics = { 171 | "accuracy": float(accuracy_score(y_test, y_pred)), 172 | "roc_auc": float(roc_auc_score(y_test, y_prob)), 173 | "avg_precision": float(average_precision_score(y_test, y_prob)), 174 | "cv_f1_macro_mean": float(cv_f1.mean()), 175 | "cv_f1_macro_std": float(cv_f1.std()), 176 | "report": classification_report(y_test, y_pred, target_names=LABELS, output_dict=True), 177 | } 178 | 179 | # 9) Save metrics and figures 180 | (outdir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8") 181 | 182 | cm = confusion_matrix(y_test, y_pred) 183 | plot_confusion_matrix(cm, charts / "confusion_matrix.png") 184 | 185 | fpr, tpr, _ = roc_curve(y_test, y_prob) 186 | plot_curve(fpr, tpr, charts / "roc_curve.png", "ROC Curve", "FPR", "TPR") 187 | 188 | prec, rec, _ = precision_recall_curve(y_test, y_prob) 189 | plot_curve(rec, prec, charts / "pr_curve.png", "Precision-Recall Curve", "Recall", "Precision") 190 | 191 | # 10) Persist artifacts 192 | # Save the whole pipeline as 'model.joblib' (vectorizer+model together) 193 | joblib.dump(pipe, outdir / "pipeline.joblib") 194 | # Also keep separate parts for compatibility with your existing CLI 195 | vec = pipe.named_steps["tfidf"] 196 | clf = pipe.named_steps["clf"] 197 | joblib.dump(vec, outdir / "vectorizer.joblib") 198 | joblib.dump(clf, outdir / "model.joblib") 199 | 200 | print("Training complete. Artifacts saved to:", str(outdir.resolve())) 201 | print("Key metrics:", json.dumps({k: metrics[k] for k in ['accuracy','roc_auc','avg_precision']}, indent=2)) 202 | print("CV F1 (macro):", f"{metrics['cv_f1_macro_mean']:.3f} ± {metrics['cv_f1_macro_std']:.3f}") 203 | print("Tip: If you want higher FAKE recall, use a decision threshold < 0.5 when classifying.") 204 | 205 | 206 | if __name__ == "__main__": 207 | main() 208 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Small I/O utilities shared across training and app code. 4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | import json 9 | from pathlib import Path 10 | from typing import Any, Union, Mapping 11 | 12 | PathLike = Union[str, Path] 13 | 14 | __all__ = [ 15 | "ensure_outdir", 16 | "save_json", 17 | "load_json", 18 | ] 19 | 20 | def ensure_outdir(path: PathLike) -> Path: 21 | """ 22 | Ensure a directory exists; create parents if needed. 23 | 24 | Parameters 25 | ---------- 26 | path : str | Path 27 | Directory path to create/ensure. 28 | 29 | Returns 30 | ------- 31 | Path 32 | Resolved directory path. 33 | """ 34 | p = Path(path) 35 | p.mkdir(parents=True, exist_ok=True) 36 | return p.resolve() 37 | 38 | 39 | def save_json(obj: Mapping[str, Any] | Any, path: PathLike, *, indent: int = 2) -> Path: 40 | """ 41 | Save a Python object as pretty-printed JSON. 42 | 43 | Parameters 44 | ---------- 45 | obj : Any 46 | JSON-serializable object. 47 | path : str | Path 48 | Output file path. 49 | indent : int 50 | JSON indent spacing. 51 | 52 | Returns 53 | ------- 54 | Path 55 | The written file path. 56 | """ 57 | p = Path(path) 58 | p.parent.mkdir(parents=True, exist_ok=True) 59 | with p.open("w", encoding="utf-8") as f: 60 | json.dump(obj, f, indent=indent) 61 | return p.resolve() 62 | 63 | 64 | def load_json(path: PathLike) -> Any: 65 | """ 66 | Load JSON from disk. 67 | 68 | Parameters 69 | ---------- 70 | path : str | Path 71 | JSON file path. 72 | 73 | Returns 74 | ------- 75 | Any 76 | Parsed JSON. 77 | """ 78 | p = Path(path) 79 | with p.open("r", encoding="utf-8") as f: 80 | return json.load(f) 81 | --------------------------------------------------------------------------------