├── LICENSE
├── README.md
├── data
│   ├── Fake.csv
│   └── True.csv
├── outputs
│   ├── charts
│   │   ├── confusion_matrix.png
│   │   ├── pr_curve.png
│   │   └── roc_curve.png
│   ├── metrics.json
│   ├── model.joblib
│   ├── pipeline.joblib
│   └── vectorizer.joblib
├── requirements.txt
└── src
    ├── detect_fake_news.py
    ├── streamlit_app.py
    ├── text_clean.py
    ├── train_model.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Amir
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fake News & Misinformation Detector
2 |
3 | Detect fake vs. real news articles with TF-IDF features and a Random Forest classifier. The project includes a training script, evaluation charts, a command-line predictor, and an interactive Streamlit web app.
4 |
5 | ---
6 |
7 | ## Table of Contents
8 | - [Overview](#overview)
9 | - [Demo](#demo)
10 | - [Project Structure](#project-structure)
11 | - [Installation](#installation)
12 | - [Dataset](#dataset)
13 | - [Training the Model](#training-the-model)
14 | - [Evaluation & Charts](#evaluation--charts)
15 | - [How It Works](#how-it-works)
16 | - [Running the Streamlit App](#running-the-streamlit-app)
17 | - [Code Modules](#code-modules)
18 | - [Technologies Used](#technologies-used)
19 | - [Future Improvements](#future-improvements)
20 | - [License](LICENSE)
21 |
22 |
23 | ---
24 |
25 | ## Overview
26 |
27 | The **Fake News & Misinformation Detector** is an end-to-end **Natural Language Processing (NLP)** project that classifies news headlines and articles as **REAL** or **FAKE**.
28 | It combines **TF-IDF feature extraction** with a **Random Forest classifier** (as implemented in `src/train_model.py`) and scores 100% accuracy on the 1,998-article sample it was trained on; the 5-fold cross-validated macro-F1 is also 1.00.
29 |
30 | The project also includes:
31 | - **Model evaluation with visual charts**
32 | - **Interactive Streamlit web app**
33 | - **Reusable and modular code structure**
34 |
35 | ---
36 |
37 | ## Demo
38 |
39 | ### Streamlit Web App
40 |
41 | When launched, the app allows you to paste or type any news headline or paragraph and analyze its credibility in real time.
42 |
43 |
44 |
45 | - Prediction: *REAL* or *FAKE*
46 | - Probability bar visualization
47 | - Adjustable fake-detection threshold
48 |
49 | ---
50 |
51 | ## Project Structure
52 |
53 | ```
54 | fake-news-detector/
55 | │
56 | ├── data/
57 | │   ├── True.csv                  # Real news (999 rows)
58 | │   └── Fake.csv                  # Fake news (999 rows)
59 | ├── outputs/
60 | │   ├── charts/
61 | │   │   ├── confusion_matrix.png  # Confusion matrix plot
62 | │   │   ├── roc_curve.png         # ROC curve plot
63 | │   │   └── pr_curve.png          # Precision-Recall curve plot
64 | │   ├── model.joblib              # Trained classifier
65 | │   ├── vectorizer.joblib         # TF-IDF vectorizer
66 | │   ├── pipeline.joblib           # Combined pipeline (vectorizer + classifier)
67 | │   └── metrics.json              # Model performance report
68 | ├── src/
69 | │   ├── text_clean.py             # Text preprocessing utilities
70 | │   ├── utils.py                  # I/O helpers
71 | │   ├── train_model.py            # Training and evaluation script
72 | │   ├── detect_fake_news.py       # CLI prediction script
73 | │   └── streamlit_app.py          # Streamlit web application
74 | ├── LICENSE
75 | ├── requirements.txt
76 | └── README.md
77 | ```
78 |
79 | ---
80 |
81 | ## Installation
82 |
83 | ### Clone the Repository
84 | ```bash
85 | git clone https://github.com/yourusername/fake-news-detector.git
86 | cd fake-news-detector
87 | ```
88 |
89 | ### Install Dependencies
90 | ```bash
91 | pip install -r requirements.txt
92 | ```
93 |
94 | Or install manually:
95 | ```bash
96 | pip install pandas numpy scikit-learn matplotlib streamlit joblib
97 | ```
98 |
99 | ---
100 |
101 | ## Dataset
102 |
103 | | File | Type | Rows | Columns |
104 | |------|------|------|----------|
105 | | `True.csv` | Real news | 999 | `title`, `text`, `subject`, `date` |
106 | | `Fake.csv` | Fake news | 999 | `title`, `text`, `subject`, `date` |
107 |
108 | > **Dataset Source:**
109 | > This project uses and modifies the [*Fake and Real News Dataset*](https://www.kaggle.com/datasets/clmentbisaillon/fake-and-real-news-dataset) by **Clément Bisaillon** (Kaggle).
110 | > Data was cleaned, header-fixed, and **downsampled to 999 REAL and 999 FAKE** news articles for balanced training and clear visualization.
111 | > Used purely for **educational and research** purposes.
112 |
113 | ---
114 |
115 | ## Training the Model
116 |
117 | Run the following command from the project root:
118 |
119 | ```bash
120 | python src/train_model.py --real data/True.csv --fake data/Fake.csv --text-col text --outdir outputs
121 | ```
122 |
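123 | This script will:
124 | 1. Load both datasets (real and fake).
125 | 2. Join `title` and the text column into a `combined_text` field and label the rows (0 = REAL, 1 = FAKE). The script does not call `text_clean.py`; lowercasing and tokenization are left to the TF-IDF vectorizer.
126 | 3. Extract **TF-IDF** features (1- to 3-grams, English stop words removed, at most 20,000 features).
127 | 4. Train a **Random Forest** classifier (400 trees) and report a 5-fold cross-validated macro-F1.
128 | 5. Save outputs:
129 |    - `outputs/pipeline.joblib` (vectorizer and classifier in one object)
130 |    - `outputs/model.joblib` and `outputs/vectorizer.joblib` (the same two steps saved separately)
131 |    - `outputs/metrics.json`
132 |    - Performance charts in `outputs/charts/` (`confusion_matrix.png`, `roc_curve.png`, `pr_curve.png`)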
133 |
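Steps 1 and 2 can be tried without the real CSV files. The snippet below is only a minimal sketch, not the project's code: `train_model.py` uses pandas, whereas this version swaps in the standard-library `csv` module, and the helper name `load_labeled` and the two sample rows are invented for illustration.

```python
import csv
import io

# Tiny in-memory stand-ins for data/True.csv and data/Fake.csv (made-up rows).
REAL_CSV = (
    "title,text,subject,date\n"
    "Budget passes,Lawmakers approved the budget.,politics,2017-01-01\n"
)
FAKE_CSV = (
    "title,text,subject,date\n"
    "Aliens land,Sources say aliens landed.,news,2017-01-02\n"
)


def load_labeled(csv_text: str, label: int) -> list[tuple[str, int]]:
    """Return (combined_text, label) pairs, joining title and text with a space."""
    rows = csv.DictReader(io.StringIO(csv_text))
    pairs = []
    for row in rows:
        combined = ((row.get("title") or "") + " " + (row.get("text") or "")).strip()
        pairs.append((combined, label))
    return pairs


# 0 = REAL, 1 = FAKE, the same label order the training script uses.
samples = load_labeled(REAL_CSV, 0) + load_labeled(FAKE_CSV, 1)
texts = [text for text, _ in samples]
labels = [label for _, label in samples]
print(texts)
print(labels)
```

Keeping the title in the combined field gives the vectorizer a little more signal per article, which is the reason the training script states for building `combined_text`.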
134 | ---
135 |
136 | ## Evaluation & Charts
137 |
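138 | After training, the model classifies every article in this dataset correctly. Note that `train_model.py` evaluates on the same 1,998 articles it trains on, so the figures below are training-set scores; the 5-fold cross-validated macro-F1 stored in `metrics.json` is also 1.00.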
139 |
140 | ### Confusion Matrix
141 |
142 |
143 | | True Label | Predicted REAL | Predicted FAKE |
144 | |-------------|----------------|----------------|
145 | | REAL | 999 ✅ | 0 ❌ |
146 | | FAKE | 0 ❌ | 999 ✅ |
147 |
148 | The model correctly classified all 1,998 samples.
149 |
150 | ---
151 |
152 | ### ROC Curve
153 |
154 |
155 | The ROC curve passes through the top-left corner (**AUC = 1.00**),
156 | which indicates perfect separability between the two classes on this data.
157 |
158 | ---
159 |
160 | ### Precision–Recall Curve
161 |
162 |
163 | Both precision and recall reach **1.00**, meaning zero false predictions.
164 |
165 | ---
166 |
167 | ### Key Metrics
168 | | Metric | Value |
169 | |---------|-------|
170 | | Accuracy | 100 % |
171 | | Precision (FAKE) | 1.00 |
172 | | Recall (FAKE) | 1.00 |
173 | | F1-Score | 1.00 |
174 | | ROC-AUC | 1.00 |
175 |
176 | > *Although perfect accuracy is achieved on this dataset, it’s a controlled sample. Real-world news data will naturally introduce noise and uncertainty.*
177 |
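The values in the Key Metrics table follow directly from the confusion-matrix counts. The sketch below recomputes them with the standard definitions; the function name `binary_metrics` is hypothetical, and the second set of counts is invented purely to show how the numbers move once errors appear.

```python
def binary_metrics(tn: int, fp: int, fn: int, tp: int) -> dict[str, float]:
    """Accuracy plus precision, recall and F1 for the positive (FAKE) class."""
    total = tn + fp + fn + tp
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {
        "accuracy": (tn + tp) / total,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Counts from the confusion matrix in this README: 999 REAL, 999 FAKE, no errors.
print(binary_metrics(tn=999, fp=0, fn=0, tp=999))

# Invented counts for contrast: 9 REAL articles flagged as FAKE, 19 FAKE missed.
print(binary_metrics(tn=990, fp=9, fn=19, tp=980))
```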
178 | ---
179 |
180 | ## How It Works
181 |
182 | ### Pipeline Overview
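183 | 1. **Text Cleaning** → At prediction time, lowercase the input and strip URLs, emails, and non-ASCII characters (`text_clean.py`). The Streamlit app uses its own cleaner, which also replaces every character other than a-z with a space.
184 | 2. **TF-IDF Vectorization** → Convert words and word n-grams (1 to 3 words) into weighted numerical features.
185 | 3. **Random Forest** → Predict the probability of the “FAKE” label.
186 | 4. **Thresholding** → If `p(fake) ≥ threshold` → FAKE, else REAL. The threshold is 0.5 during training evaluation, defaults to 0.40 in the CLI, and defaults to 0.50 in the Streamlit app.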
187 |
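To make the TF-IDF step concrete, here is a small hand-rolled sketch. It is a simplification under stated assumptions rather than what the project runs: it tokenizes on whitespace, uses unigrams only, and skips stop-word removal and the feature cap. The weighting itself follows the formula scikit-learn documents for `TfidfVectorizer` with `sublinear_tf=True` and the default smoothed IDF, followed by L2 normalization. The function name `tfidf` and the three sample documents are made up.

```python
import math
from collections import Counter


def tfidf(docs: list[str]) -> list[dict[str, float]]:
    """Sublinear-TF, smoothed-IDF, L2-normalized weights for whitespace tokens."""
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    # Document frequency: in how many documents each term appears.
    df = Counter(term for tokens in tokenized for term in set(tokens))
    vectors = []
    for tokens in tokenized:
        counts = Counter(tokens)
        weights = {
            # (1 + ln tf) * (ln((1 + n) / (1 + df)) + 1)
            term: (1 + math.log(tf)) * (math.log((1 + n_docs) / (1 + df[term])) + 1)
            for term, tf in counts.items()
        }
        norm = math.sqrt(sum(w * w for w in weights.values())) or 1.0
        vectors.append({term: w / norm for term, w in weights.items()})
    return vectors


docs = [
    "senate passes budget bill",
    "shocking secret budget hoax",
    "senate hoax exposed",
]
for vector in tfidf(docs):
    print({term: round(weight, 3) for term, weight in vector.items()})
```

Terms that occur in fewer documents (such as `passes`) receive a larger weight than terms shared across documents (such as `budget`), which is what lets the classifier focus on distinctive vocabulary.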
188 | ---
189 |
190 | ### Example: Command-Line Prediction
191 | ```bash
192 | python src/detect_fake_news.py --model outputs/model.joblib --vectorizer outputs/vectorizer.joblib --text "It s tough sometimes to imagine that Donald Trump has five children since it s clear from Monday s speech in front of 40,000 Boy Scouts and other attendees at the Boy Scouts Jamboree in West Virginia that he has absolutely no idea what kind of talk is appropriate for children.While most adults would take this opportunity to offer some pearls of adult wisdom or cheerlead the Boy Scouts toward their futures, Trump chose to deliver a tirade of Trumpisms.Like almost any time Trump has tried to string together more than a couple of words at a time, most of his speech was an inarticulate mess which consisted of his trademark whining, a wee bit of swearing and a pointless anecdote about a burned out rich guy at a cocktail party."
193 | ```
194 |
195 | Output:
196 | ```
197 | Label: FAKE | Fake probability: 0.560 | Threshold: 0.40
198 | ```
199 |
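The label in that output depends on the threshold as much as on the probability. The sketch below applies the same comparison the CLI uses (`prob >= threshold`) to the probability from the example; `label_for` is a hypothetical helper name, and the extra thresholds are chosen only for illustration.

```python
def label_for(prob_fake: float, threshold: float = 0.40) -> str:
    """Return "FAKE" when the fake probability reaches the threshold, else "REAL"."""
    return "FAKE" if prob_fake >= threshold else "REAL"


prob = 0.560  # fake probability from the example output
for threshold in (0.40, 0.50, 0.60):
    label = label_for(prob, threshold)
    print(f"Label: {label} | Fake probability: {prob:.3f} | Threshold: {threshold:.2f}")
```

With this probability the article is labeled FAKE at thresholds 0.40 and 0.50 but REAL at 0.60. Lowering the threshold raises FAKE recall at the cost of more false alarms.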
200 | ---
201 |
202 | ## Running the Streamlit App
203 |
204 | ### Launch the App
205 | ```bash
206 | streamlit run src/streamlit_app.py
207 | ```
208 |
209 | Then open the local web interface:
210 | ```
211 | http://localhost:8501
212 | ```
213 |
214 | ### App Features
215 | - Paste any headline or paragraph
216 | - Analyze with one click
217 | - Adjust FAKE probability threshold
218 | - See model file locations and loaded status in sidebar
219 |
220 | ---
221 |
222 | ## Code Modules
223 |
224 | | Module | Purpose |
225 | |---------|----------|
226 | | `text_clean.py` | Handles text normalization (lowercasing, regex-based cleaning) |
227 | | `utils.py` | Ensures output directories exist and handles JSON I/O |
228 | | `train_model.py` | Loads data, trains the model, and generates metrics and plots |
229 | | `detect_fake_news.py` | CLI script for predicting individual samples |
230 | | `streamlit_app.py` | Streamlit web app for interactive user testing |
231 |
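As an illustration of what `text_clean.py` does with its default settings, the sketch below condenses the same four regular expressions into a single function. It leaves out the keyword switches and the non-string guard of the real `clean_text`, and the sample sentence is invented.

```python
import re

# Same patterns as src/text_clean.py, compiled once at module level.
URL_RE = re.compile(r"https?://\S+")
EMAIL_RE = re.compile(r"\S+@\S+")
NON_ASCII_RE = re.compile(r"[^\x00-\x7F]+")
EXTRA_SPACE_RE = re.compile(r"\s+")


def clean_text(text: str) -> str:
    """Lowercase, drop URLs, emails and non-ASCII runs, then collapse whitespace."""
    s = text.lower()
    s = URL_RE.sub(" ", s)
    s = EMAIL_RE.sub(" ", s)
    s = NON_ASCII_RE.sub(" ", s)
    return EXTRA_SPACE_RE.sub(" ", s).strip()


raw = "BREAKING: Read more at https://example.com/story or email tips@example.com  “now”!"
print(clean_text(raw))
# prints: breaking: read more at or email now !
```

ASCII punctuation such as `:` and `!` survives this cleaner; only the Streamlit app's separate `clean_text` strips it.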
232 | ---
233 |
234 | ## Technologies Used
235 |
236 | - **Python 3.10+**
237 | - **scikit-learn** → TF-IDF Vectorizer, Random Forest classifier, metrics and cross-validation
238 | - **pandas / numpy** → Data manipulation
239 | - **matplotlib** → Model visualization
240 | - **joblib** → Model persistence
241 | - **Streamlit** → Web interface
242 |
243 | ---
244 |
245 | ## Future Improvements
246 | - Integrate **BERT / DistilBERT** for contextual language understanding
247 | - Extend dataset for **multi-language** fake news detection
248 | - Add **Explainable AI** (LIME / SHAP) for model transparency
249 | - Deploy live on **Streamlit Cloud** or **Hugging Face Spaces**
250 |
--------------------------------------------------------------------------------
/outputs/charts/confusion_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/charts/confusion_matrix.png
--------------------------------------------------------------------------------
/outputs/charts/pr_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/charts/pr_curve.png
--------------------------------------------------------------------------------
/outputs/charts/roc_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/charts/roc_curve.png
--------------------------------------------------------------------------------
/outputs/metrics.json:
--------------------------------------------------------------------------------
{
  "accuracy": 1.0,
  "roc_auc": 1.0,
  "avg_precision": 1.0,
  "cv_f1_macro_mean": 1.0,
  "cv_f1_macro_std": 0.0,
  "report": {
    "REAL": {
      "precision": 1.0,
      "recall": 1.0,
      "f1-score": 1.0,
      "support": 999.0
    },
    "FAKE": {
      "precision": 1.0,
      "recall": 1.0,
      "f1-score": 1.0,
      "support": 999.0
    },
    "accuracy": 1.0,
    "macro avg": {
      "precision": 1.0,
      "recall": 1.0,
      "f1-score": 1.0,
      "support": 1998.0
    },
    "weighted avg": {
      "precision": 1.0,
      "recall": 1.0,
      "f1-score": 1.0,
      "support": 1998.0
    }
  }
}
--------------------------------------------------------------------------------
/outputs/model.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/model.joblib
--------------------------------------------------------------------------------
/outputs/pipeline.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/pipeline.joblib
--------------------------------------------------------------------------------
/outputs/vectorizer.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmirhosseinHonardoust/Fake-News-Detector/cef2aec1bf1d3d3478a4356e84894a172e22f51f/outputs/vectorizer.joblib
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas>=2.0
2 | scikit-learn>=1.3
3 | matplotlib>=3.7
4 | joblib>=1.3
5 | numpy>=1.25
6 | streamlit>=1.33
7 |
--------------------------------------------------------------------------------
/src/detect_fake_news.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
from __future__ import annotations
import argparse
from pathlib import Path
import joblib

from text_clean import clean_text

def load_pipeline_or_parts(pipeline_path: str | None,
                           model_path: str | None,
                           vectorizer_path: str | None):
    if pipeline_path:
        pipe = joblib.load(pipeline_path)
        return pipe, None, None
    # fallback to separate artifacts
    if not (model_path and vectorizer_path):
        raise ValueError(
            "Provide --pipeline OR both --model and --vectorizer."
        )
    clf = joblib.load(model_path)
    vec = joblib.load(vectorizer_path)
    return None, clf, vec

def main() -> None:
    ap = argparse.ArgumentParser(description="Detect fake news for a single text.")
    ap.add_argument("--pipeline", help="Path to pipeline.joblib (preferred).")
    ap.add_argument("--model", help="Path to model.joblib (fallback).")
    ap.add_argument("--vectorizer", help="Path to vectorizer.joblib (fallback).")
    ap.add_argument("--text", required=True, help="Headline or article text.")
    ap.add_argument("--threshold", type=float, default=0.40,
                    help="Decision threshold for FAKE (default: 0.40).")
    args = ap.parse_args()

    pipe, clf, vec = load_pipeline_or_parts(args.pipeline, args.model, args.vectorizer)

    s = clean_text(args.text)

    if pipe is not None:
        prob = float(pipe.predict_proba([s])[0, 1])
    else:
        X = vec.transform([s])
        prob = float(clf.predict_proba(X)[0, 1])

    label = "FAKE" if prob >= args.threshold else "REAL"
    print(f"Label: {label} | Fake probability: {prob:.3f} | Threshold: {args.threshold:.2f}")

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/src/streamlit_app.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
import argparse
from pathlib import Path
import re
import joblib
import streamlit as st

# ---------- text cleaning ----------
def clean_text(text: str) -> str:
    # Lowercase, drop URLs, keep only a-z and whitespace, collapse whitespace
    text = text.lower()
    text = re.sub(r"http\S+", "", text)
    text = re.sub(r"[^a-z\s]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text

# ---------- path helpers ----------
def project_root() -> Path:
    # this file is in src/, project root is parent directory
    return Path(__file__).resolve().parents[1]

def default_paths():
    root = project_root()
    out = root / "outputs"
    return {
        "pipeline": out / "pipeline.joblib",
        "model": out / "model.joblib",
        "vectorizer": out / "vectorizer.joblib",
    }

def load_pipeline_or_parts(pipeline_path: Path, model_path: Path, vectorizer_path: Path):
    # Prefer the combined pipeline; fall back to separate model + vectorizer
    if pipeline_path and pipeline_path.exists():
        return joblib.load(pipeline_path), None, None
    if model_path.exists() and vectorizer_path.exists():
        clf = joblib.load(model_path)
        vec = joblib.load(vectorizer_path)
        return None, clf, vec
    return None, None, None

# ---------- streamlit app ----------
def main():
    # parse CLI overrides but give safe defaults relative to repo root
    dp = default_paths()

    ap = argparse.ArgumentParser(add_help=False)
    ap.add_argument("--pipeline", default=str(dp["pipeline"]))
    ap.add_argument("--model", default=str(dp["model"]))
    ap.add_argument("--vectorizer", default=str(dp["vectorizer"]))
    args, _ = ap.parse_known_args()

    pipeline_path = Path(args.pipeline).resolve()
    model_path = Path(args.model).resolve()
    vectorizer_path = Path(args.vectorizer).resolve()

    st.set_page_config(page_title="Fake News Detector", page_icon="📰", layout="centered")
    st.title("📰 Fake News & Misinformation Detector")
    # train_model.py fits a RandomForestClassifier, so the caption names that model
    st.caption("TF-IDF + Random Forest")

    # sidebar: show where we look for files
    with st.sidebar:
        st.subheader("Model Artifacts")
        st.code(f"pipeline:   {pipeline_path}\nmodel:      {model_path}\nvectorizer: {vectorizer_path}")
        st.write(f"Exists → pipeline: **{pipeline_path.exists()}**, "
                 f"model: **{model_path.exists()}**, vectorizer: **{vectorizer_path.exists()}**")

    pipe, clf, vec = load_pipeline_or_parts(pipeline_path, model_path, vectorizer_path)
    if pipe is None and (clf is None or vec is None):
        st.error(
            "Model artifacts not found.\n\n"
            "• Ensure you trained and saved files to `outputs/`\n"
            "• Or run Streamlit with explicit paths, e.g.:\n"
            "  `streamlit run src/streamlit_app.py -- --model C:/path/outputs/model.joblib --vectorizer C:/path/outputs/vectorizer.joblib`\n"
            "• From CLI, also try: `python -c \"import os; print(os.getcwd())\"` to see your working directory."
        )
        st.stop()

    txt = st.text_area("Paste headline or article text:", height=200)
    threshold = st.slider("FAKE decision threshold", 0.05, 0.95, 0.50, 0.01)

    if st.button("Analyze") and txt.strip():
        s = clean_text(txt)
        if pipe is not None:
            prob_fake = float(pipe.predict_proba([s])[0, 1])
        else:
            X = vec.transform([s])
            prob_fake = float(clf.predict_proba(X)[0, 1])

        label = "FAKE" if prob_fake >= threshold else "REAL"
        st.metric("Prediction", label)
        # The bar shows the probability of the predicted label
        st.progress(prob_fake if label == "FAKE" else 1 - prob_fake,
                    text=f"Fake probability: {prob_fake:.1%} (threshold {threshold:.2f})")

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/src/text_clean.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""
Utility functions for lightweight text normalization used by the fake-news detector.

Default behavior:
- Lowercase
- Remove URLs and emails
- Strip non-ASCII characters
- Collapse multiple whitespace into a single space
"""

from __future__ import annotations

import re
from typing import Iterable, List, Sequence, Union, Optional

# Precompiled patterns (module-level so they compile once)
URL_RE = re.compile(r"https?://\S+")
EMAIL_RE = re.compile(r"\S+@\S+")
NON_ASCII_RE = re.compile(r"[^\x00-\x7F]+")
EXTRA_SPACE_RE = re.compile(r"\s+")

__all__ = [
    "clean_text",
    "clean_many",
]

def clean_text(
    text: Optional[str],
    *,
    lowercase: bool = True,
    remove_urls: bool = True,
    remove_emails: bool = True,
    remove_non_ascii: bool = True,
    collapse_whitespace: bool = True,
) -> str:
    """
    Clean a single text string with sensible defaults.

    Parameters
    ----------
    text : str | None
        Input text to normalize. Non-string values are treated as empty.
    lowercase : bool
        Convert to lowercase.
    remove_urls : bool
        Remove URL-like substrings.
    remove_emails : bool
        Remove email-like substrings.
    remove_non_ascii : bool
        Strip non-ASCII characters.
    collapse_whitespace : bool
        Replace runs of whitespace with a single space and strip ends.

    Returns
    -------
    str
        The normalized text (possibly empty string).
    """
    if not isinstance(text, str):
        return ""

    s = text

    if lowercase:
        s = s.lower()
    if remove_urls:
        s = URL_RE.sub(" ", s)
    if remove_emails:
        s = EMAIL_RE.sub(" ", s)
    if remove_non_ascii:
        s = NON_ASCII_RE.sub(" ", s)
    if collapse_whitespace:
        s = EXTRA_SPACE_RE.sub(" ", s).strip()

    return s


def clean_many(
    texts: Sequence[Optional[str]],
    **kwargs,
) -> List[str]:
    """
    Vectorized convenience to clean a list/sequence of texts.

    Parameters
    ----------
    texts : sequence of str | None
        Iterable of raw texts.
    **kwargs
        Passed through to `clean_text`.

    Returns
    -------
    list[str]
        Cleaned strings, same order as input.
    """
    return [clean_text(t, **kwargs) for t in texts]
--------------------------------------------------------------------------------
/src/train_model.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Final, Tuple

import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    average_precision_score,
)
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
from sklearn.pipeline import Pipeline

# -----------------------------
# Helpers
# -----------------------------
LABELS: Final[Tuple[str, str]] = ("REAL", "FAKE")


def ensure_dir(path: Path) -> Path:
    path.mkdir(parents=True, exist_ok=True)
    return path


def read_csv_any(path: Path, nrows: int | None = None) -> pd.DataFrame:
    # Try UTF-8 first, then fall back to latin-1 for files with legacy encoding
    try:
        return pd.read_csv(path, nrows=nrows, encoding="utf-8")
    except UnicodeDecodeError:
        return pd.read_csv(path, nrows=nrows, encoding="latin-1")


def pick_text_column(df: pd.DataFrame, preferred: str) -> str:
    if preferred in df.columns:
        return preferred
    for alt in ("combined_text", "text", "content", "article", "body"):
        if alt in df.columns:
            return alt
    # if still not found, try title-only as last resort
    if "title" in df.columns:
        return "title"
    raise ValueError(
        f"No suitable text column found. Available columns: {list(df.columns)[:20]}"
    )


# -----------------------------
# Plot utilities
# -----------------------------
def plot_confusion_matrix(cm: np.ndarray, out: Path, title: str = "Confusion Matrix") -> None:
    fig, ax = plt.subplots(figsize=(7, 6))
    im = ax.imshow(cm, interpolation="nearest")
    ax.set_title(title)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_xticks([0, 1])
    ax.set_xticklabels(LABELS)
    ax.set_yticks([0, 1])
    ax.set_yticklabels(LABELS)

    # Write the count into each cell
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, int(cm[i, j]), ha="center", va="center")

    fig.tight_layout()
    fig.savefig(out, dpi=150)
    plt.close(fig)


def plot_curve(x: np.ndarray, y: np.ndarray, out: Path, title: str, xlabel: str, ylabel: str) -> None:
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(x, y)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.tight_layout()
    fig.savefig(out, dpi=150)
    plt.close(fig)


# -----------------------------
# Main training routine
# -----------------------------
def main() -> None:
    ap = argparse.ArgumentParser(
        description="Train Fake News Detector with TF-IDF (1–3 grams) + RandomForest"
    )
    ap.add_argument("--real", required=True, help="Path to True.csv")
    ap.add_argument("--fake", required=True, help="Path to Fake.csv")
    ap.add_argument(
        "--text-col",
        default="text",
        help="Name of the text column (before combining). Default: text",
    )
    ap.add_argument("--outdir", default="outputs", help="Output directory")
    a = ap.parse_args()

    outdir = ensure_dir(Path(a.outdir))
    charts = ensure_dir(outdir / "charts")

    # 1) Load data
    df_real = read_csv_any(Path(a.real))
    df_fake = read_csv_any(Path(a.fake))

    # 2) Build 'combined_text' = title + text (more signal)
    for df in (df_real, df_fake):
        title = df["title"].fillna("") if "title" in df.columns else ""
        # pick_text_column raises a clear ValueError when no text column exists
        txt = df[pick_text_column(df, a.text_col)].fillna("")
        df["combined_text"] = (title + " " + txt).str.strip()

    # 3) Prepare X, y with balanced labels
    X = pd.concat([df_real["combined_text"], df_fake["combined_text"]], ignore_index=True)
    y = np.array([0] * len(df_real) + [1] * len(df_fake))  # 0=REAL, 1=FAKE

    # 4) No hold-out split: the same rows are used for training and evaluation,
    #    so the metrics computed in step 8 are training-set scores (optimistic)
    X_train, X_test, y_train, y_test = X, X, y, y

    # 5) Pipeline: TF-IDF (1–3 grams) + RandomForest
    pipe = Pipeline(
        steps=[
            (
                "tfidf",
                TfidfVectorizer(
                    sublinear_tf=True,
                    stop_words="english",
                    ngram_range=(1, 3),
                    max_df=0.8,
                    min_df=3,
                    max_features=20_000,
                ),
            ),
            (
                "clf",
                RandomForestClassifier(
                    n_estimators=400,
                    max_depth=None,
                    random_state=42,
                    class_weight="balanced_subsample",
                    n_jobs=-1,
                ),
            ),
        ]
    )

    # 6) 5-fold stratified CV over the full dataset for a more robust estimate
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    cv_f1 = cross_val_score(pipe, X_train, y_train, cv=cv, scoring="f1_macro")
    print(f"Cross-validated F1 (train split, 5-fold): {cv_f1.mean():.3f} ± {cv_f1.std():.3f}")

    # 7) Fit on the full dataset
    pipe.fit(X_train, y_train)

    # 8) Evaluate on X_test (identical to the training data, see step 4)
    y_prob = pipe.predict_proba(X_test)[:, 1]
    y_pred = (y_prob >= 0.5).astype(int)

    metrics = {
        "accuracy": float(accuracy_score(y_test, y_pred)),
        "roc_auc": float(roc_auc_score(y_test, y_prob)),
        "avg_precision": float(average_precision_score(y_test, y_prob)),
        "cv_f1_macro_mean": float(cv_f1.mean()),
        "cv_f1_macro_std": float(cv_f1.std()),
        "report": classification_report(y_test, y_pred, target_names=LABELS, output_dict=True),
    }

    # 9) Save metrics and figures
    (outdir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")

    cm = confusion_matrix(y_test, y_pred)
    plot_confusion_matrix(cm, charts / "confusion_matrix.png")

    fpr, tpr, _ = roc_curve(y_test, y_prob)
    plot_curve(fpr, tpr, charts / "roc_curve.png", "ROC Curve", "FPR", "TPR")

    prec, rec, _ = precision_recall_curve(y_test, y_prob)
    plot_curve(rec, prec, charts / "pr_curve.png", "Precision-Recall Curve", "Recall", "Precision")

    # 10) Persist artifacts
    # Save the whole pipeline as 'pipeline.joblib' (vectorizer + model together)
    joblib.dump(pipe, outdir / "pipeline.joblib")
    # Also keep separate parts for compatibility with your existing CLI
    vec = pipe.named_steps["tfidf"]
    clf = pipe.named_steps["clf"]
    joblib.dump(vec, outdir / "vectorizer.joblib")
    joblib.dump(clf, outdir / "model.joblib")

    print("Training complete. Artifacts saved to:", str(outdir.resolve()))
    print("Key metrics:", json.dumps({k: metrics[k] for k in ['accuracy','roc_auc','avg_precision']}, indent=2))
    print("CV F1 (macro):", f"{metrics['cv_f1_macro_mean']:.3f} ± {metrics['cv_f1_macro_std']:.3f}")
    print("Tip: If you want higher FAKE recall, use a decision threshold < 0.5 when classifying.")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""
Small I/O utilities shared across training and app code.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Union, Mapping

PathLike = Union[str, Path]

__all__ = [
    "ensure_outdir",
    "save_json",
    "load_json",
]

def ensure_outdir(path: PathLike) -> Path:
    """
    Ensure a directory exists; create parents if needed.

    Parameters
    ----------
    path : str | Path
        Directory path to create/ensure.

    Returns
    -------
    Path
        Resolved directory path.
    """
    p = Path(path)
    p.mkdir(parents=True, exist_ok=True)
    return p.resolve()


def save_json(obj: Mapping[str, Any] | Any, path: PathLike, *, indent: int = 2) -> Path:
    """
    Save a Python object as pretty-printed JSON.

    Parameters
    ----------
    obj : Any
        JSON-serializable object.
    path : str | Path
        Output file path.
    indent : int
        JSON indent spacing.

    Returns
    -------
    Path
        The written file path.
    """
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    with p.open("w", encoding="utf-8") as f:
        json.dump(obj, f, indent=indent)
    return p.resolve()


def load_json(path: PathLike) -> Any:
    """
    Load JSON from disk.

    Parameters
    ----------
    path : str | Path
        JSON file path.

    Returns
    -------
    Any
        Parsed JSON.
    """
    p = Path(path)
    with p.open("r", encoding="utf-8") as f:
        return json.load(f)
--------------------------------------------------------------------------------