├── requirements.txt
├── docs
│   └── getting_started.md
├── src
│   ├── toxicity_toolkit.egg-info
│   │   ├── dependency_links.txt
│   │   ├── top_level.txt
│   │   ├── entry_points.txt
│   │   ├── requires.txt
│   │   ├── SOURCES.txt
│   │   └── PKG-INFO
│   └── toxicity_toolkit
│       ├── __pycache__
│       │   └── cli.cpython-310.pyc
│       ├── train
│       │   ├── __pycache__
│       │   │   ├── metrics.cpython-310.pyc
│       │   │   └── trainer.cpython-310.pyc
│       │   ├── metrics.py
│       │   └── trainer.py
│       ├── data
│       │   ├── __pycache__
│       │   │   └── datamodule.cpython-310.pyc
│       │   ├── adapters
│       │   │   ├── __pycache__
│       │   │   │   ├── jigsaw.cpython-310.pyc
│       │   │   │   └── hatexplain.cpython-310.pyc
│       │   │   ├── hatexplain.py
│       │   │   └── jigsaw.py
│       │   └── datamodule.py
│       ├── infer
│       │   ├── __pycache__
│       │   │   └── predictor.cpython-310.pyc
│       │   └── predictor.py
│       ├── explain
│       │   ├── __pycache__
│       │   │   └── shap_explain.cpython-310.pyc
│       │   └── shap_explain.py
│       ├── models
│       │   ├── __pycache__
│       │   │   └── multilabel_bert.cpython-310.pyc
│       │   └── multilabel_bert.py
│       └── cli.py
├── apps
│   └── dashboard
│       └── app.py
├── tests
│   └── test_cli.py
├── pyproject.toml
├── README.md
└── data
    └── jigsaw
        └── jigsaw
            └── preview.jsonl

/requirements.txt:
--------------------------------------------------------------------------------
1 | -e .
2 | 
--------------------------------------------------------------------------------
/docs/getting_started.md:
--------------------------------------------------------------------------------
1 | # Getting Started
--------------------------------------------------------------------------------
/src/toxicity_toolkit.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/apps/dashboard/app.py:
--------------------------------------------------------------------------------
1 | # Streamlit dashboard placeholder
2 | 
--------------------------------------------------------------------------------
/src/toxicity_toolkit.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | toxicity_toolkit
2 | 
--------------------------------------------------------------------------------
/tests/test_cli.py:
--------------------------------------------------------------------------------
1 | def test_placeholder():
2 |     assert True
3 | 
--------------------------------------------------------------------------------
/src/toxicity_toolkit.egg-info/entry_points.txt:
--------------------------------------------------------------------------------
1 | [console_scripts]
2 | toxdet = toxicity_toolkit.cli:app
3 | 
--------------------------------------------------------------------------------
/src/toxicity_toolkit/__pycache__/cli.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nealgatech-web/toxic-detection-toolkit/HEAD/src/toxicity_toolkit/__pycache__/cli.cpython-310.pyc
--------------------------------------------------------------------------------
/src/toxicity_toolkit/train/__pycache__/metrics.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nealgatech-web/toxic-detection-toolkit/HEAD/src/toxicity_toolkit/train/__pycache__/metrics.cpython-310.pyc
--------------------------------------------------------------------------------
/src/toxicity_toolkit/train/__pycache__/trainer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nealgatech-web/toxic-detection-toolkit/HEAD/src/toxicity_toolkit/train/__pycache__/trainer.cpython-310.pyc
--------------------------------------------------------------------------------
/src/toxicity_toolkit/data/__pycache__/datamodule.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nealgatech-web/toxic-detection-toolkit/HEAD/src/toxicity_toolkit/data/__pycache__/datamodule.cpython-310.pyc
--------------------------------------------------------------------------------
/src/toxicity_toolkit/infer/__pycache__/predictor.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nealgatech-web/toxic-detection-toolkit/HEAD/src/toxicity_toolkit/infer/__pycache__/predictor.cpython-310.pyc
--------------------------------------------------------------------------------
/src/toxicity_toolkit/data/adapters/__pycache__/jigsaw.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nealgatech-web/toxic-detection-toolkit/HEAD/src/toxicity_toolkit/data/adapters/__pycache__/jigsaw.cpython-310.pyc
--------------------------------------------------------------------------------
/src/toxicity_toolkit/explain/__pycache__/shap_explain.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nealgatech-web/toxic-detection-toolkit/HEAD/src/toxicity_toolkit/explain/__pycache__/shap_explain.cpython-310.pyc
--------------------------------------------------------------------------------
/src/toxicity_toolkit/models/__pycache__/multilabel_bert.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nealgatech-web/toxic-detection-toolkit/HEAD/src/toxicity_toolkit/models/__pycache__/multilabel_bert.cpython-310.pyc
--------------------------------------------------------------------------------
/src/toxicity_toolkit/data/adapters/__pycache__/hatexplain.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nealgatech-web/toxic-detection-toolkit/HEAD/src/toxicity_toolkit/data/adapters/__pycache__/hatexplain.cpython-310.pyc
--------------------------------------------------------------------------------
/src/toxicity_toolkit.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | transformers>=4.44.0
2 | datasets>=2.20.0
3 | torch>=2.1.0
4 | scikit-learn>=1.3.0
5 | pandas>=2.1.0
6 | numpy>=1.26.0
7 | typer>=0.12.3
8 | rich>=13.7.0
9 | pyyaml>=6.0.1
10 | tqdm>=4.66.0
11 | fastapi>=0.110.0
12 | uvicorn>=0.29.0
13 | shap>=0.44.0
14 | lime>=0.2.0.1
15 | streamlit>=1.37.0
16 | plotly>=5.22.0
--------------------------------------------------------------------------------
/src/toxicity_toolkit/data/adapters/hatexplain.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | from datasets import load_dataset
3 | 
4 | 
5 | def load_hatexplain(split: str = "train"):
6 |     """Convert HateXplain rows into the shared {text, labels[]} schema."""
7 |     ds = load_dataset("hatexplain", split=split)
8 |     out = []
9 |     for row in ds:
10 |         # Posts are stored as token lists; annotator votes live under "annotators".
11 |         text = " ".join(row.get("post_tokens", []))
12 |         votes = row.get("annotators", {}).get("label", [])
13 |         # Majority vote; class id 0 is "hatespeech" in the dataset's ClassLabel order.
14 |         majority = Counter(votes).most_common(1)[0][0] if votes else None
15 |         labels = ["hate"] if majority == 0 else []
16 |         out.append({"text": text, "labels": labels})
17 |     return out
18 | 
--------------------------------------------------------------------------------
/src/toxicity_toolkit/train/metrics.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import f1_score, precision_score, recall_score
2 | 
3 | def multilabel_metrics(y_true, y_pred, average="macro"):
4 |     return {
5 |         "f1": float(f1_score(y_true, y_pred, average=average, zero_division=0)),
6 |         "precision": float(precision_score(y_true, y_pred, average=average, zero_division=0)),
7 |         "recall": float(recall_score(y_true, y_pred, average=average, zero_division=0)),
8 |     }
9 | 
--------------------------------------------------------------------------------
/src/toxicity_toolkit/models/multilabel_bert.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from transformers import AutoModel
3 | 
4 | class MultiLabelBertHead(nn.Module):
5 |     def __init__(self, base, num_labels):
6 |         super().__init__()
7 |         self.encoder = AutoModel.from_pretrained(base)
8 |         hidden = self.encoder.config.hidden_size
9 |         self.dropout = nn.Dropout(0.1)
10 |         self.classifier = nn.Linear(hidden, num_labels)
11 | 
12 |     def forward(self, input_ids, attention_mask):
13 |         out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
14 |         cls = out.last_hidden_state[:,0,:]
15 |         return self.classifier(self.dropout(cls))
16 | 
--------------------------------------------------------------------------------
/src/toxicity_toolkit.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | README.md
2 | pyproject.toml
3 | src/toxicity_toolkit/cli.py
4 | src/toxicity_toolkit.egg-info/PKG-INFO
5 | src/toxicity_toolkit.egg-info/SOURCES.txt
6 | src/toxicity_toolkit.egg-info/dependency_links.txt
7 | src/toxicity_toolkit.egg-info/entry_points.txt
8 | src/toxicity_toolkit.egg-info/requires.txt
9 | src/toxicity_toolkit.egg-info/top_level.txt
10 | src/toxicity_toolkit/data/datamodule.py
11 | src/toxicity_toolkit/data/adapters/hatexplain.py
12 | src/toxicity_toolkit/data/adapters/jigsaw.py
13 | src/toxicity_toolkit/explain/shap_explain.py
14 | src/toxicity_toolkit/infer/predictor.py
15 | src/toxicity_toolkit/models/multilabel_bert.py
16 | src/toxicity_toolkit/train/metrics.py
17 | src/toxicity_toolkit/train/trainer.py
18 | tests/test_cli.py
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=68", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name = "toxicity-toolkit"
7 | version = "0.1.0"
8 | description = "Open-source modular toxic content detection toolkit"
9 | authors = [{ name = "Community Contributors" }]
10 | readme = "README.md"
11 | requires-python = ">=3.9"
12 | dependencies = [
13 |     "transformers>=4.44.0",
14 |     "datasets>=2.20.0",
15 |     "torch>=2.1.0",
16 |     "scikit-learn>=1.3.0",
17 |     "pandas>=2.1.0",
18 |     "numpy>=1.26.0",
19 |     "typer>=0.12.3",
20 |     "rich>=13.7.0",
21 |     "pyyaml>=6.0.1",
22 |     "tqdm>=4.66.0",
23 |     "fastapi>=0.110.0",
24 |     "uvicorn>=0.29.0",
25 |     "shap>=0.44.0",
26 |     "lime>=0.2.0.1",
27 |     "streamlit>=1.37.0",
28 |     "plotly>=5.22.0"
29 | ]
30 | 
31 | [project.scripts]
32 | toxdet = "toxicity_toolkit.cli:app"
33 | 
--------------------------------------------------------------------------------
/src/toxicity_toolkit/explain/shap_explain.py:
--------------------------------------------------------------------------------
1 | import shap
2 | import numpy as np
3 | import torch
4 | from ..infer.predictor import Predictor
5 | 
6 | def shap_explain_text(run_dir, text: str):
7 |     """
8 |     Generate SHAP explanations for a single text input using a trained model.
9 |     Requires that the model directory (run_dir) contains the model + tokenizer.
10 |     """
11 |     print(f"🔍 Explaining text with SHAP: {text[:80]}...")
12 |     pred = Predictor.from_run(run_dir)
13 |     tokenizer = pred.tokenizer
14 | 
15 |     # Define wrapper function that returns model probabilities
16 |     def f(texts):
17 |         outs = pred.predict(texts)
18 |         return np.array([[o["scores"][lab] for lab in pred.labels] for o in outs])
19 | 
20 |     # SHAP text masker and explainer
21 |     explainer = shap.Explainer(f, shap.maskers.Text(tokenizer.sep_token or " "))
22 |     shap_values = explainer([text])
23 | 
24 |     # Display in notebook or supported environments
25 |     shap.plots.text(shap_values[0])
26 |     print("✅ SHAP explanation generated.")
27 | 
--------------------------------------------------------------------------------
/src/toxicity_toolkit/data/adapters/jigsaw.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapter for Civil Comments (Jigsaw Toxic Comment replacement).
3 | Fully works with Hugging Face `datasets` >= 3.x.
4 | """
5 | 
6 | from datasets import load_dataset
7 | from typing import List, Dict
8 | 
9 | # Map columns to your label set
10 | LABELS = [
11 |     "toxicity",
12 |     "severe_toxicity",
13 |     "obscene",
14 |     "threat",
15 |     "insult",
16 |     "identity_attack",
17 | ]
18 | 
19 | 
20 | def load_jigsaw(split: str = "train") -> List[Dict]:
21 |     """
22 |     Load and convert the Civil Comments dataset into the {text, labels[]} schema.
23 |     Civil Comments is the successor to the original Jigsaw Toxic Comments data; its fractional annotator scores are thresholded at 0.5 to produce binary labels.
24 |     """
25 |     ds = load_dataset("civil_comments", split=split)
26 |     out = []
27 | 
28 |     for row in ds:
29 |         text = row.get("text") or row.get("comment_text") or ""
30 |         labels = []
31 |         for lab in LABELS:
32 |             # civil_comments labels are fractional annotator scores in [0, 1]
33 |             if float(row.get(lab, 0.0)) >= 0.5:
34 |                 labels.append(lab)
35 |         out.append({"text": text, "labels": labels})
36 | 
37 |     return out
38 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Toxic Content Detection Toolkit
2 | 
3 | Open-source, modular pipeline for detecting and classifying toxic text across multiple domains.
4 | 
5 | ## Key Features
6 | - Multi-label classification: hate, harassment, misinformation, spam (extensible)
7 | - Dataset adapters: HateXplain, Civil Comments (Jigsaw replacement), Reddit — pluggable interface
8 | - Pretrained checkpoints: BERT/DistilBERT fine-tuned (community-hostable)
9 | - Visualization dashboard: toxicity heatmaps by topic/community
10 | - Explainability: SHAP or LIME for model interpretability
11 | - Batteries included: CLI, config, evaluation, unit tests, CI, model cards
12 | 
13 | ## Quickstart
14 | 
15 | ```bash
16 | # 1. Install
17 | pip3 install -e .
18 | 
19 | # 2. Explore the CLI
20 | toxdet --help
21 | 
22 | # 3. Download and preprocess a dataset
23 | # You can choose one of the supported datasets:
24 | #   - jigsaw → uses the Civil Comments dataset (recommended)
25 | #   - hatexplain → optional, requires manual download
26 | toxdet data prepare --dataset jigsaw --out data/jigsaw
27 | 
28 | # 4. Train a multi-label model
29 | toxdet train --dataset jigsaw --model bert-base-uncased --epochs 1 --output runs/bert_cc
30 | 
31 | # 5. Evaluate
32 | toxdet eval runs/bert_cc --split validation
33 | 
34 | # 6. Inference (plain text in, JSON out)
35 | echo "You are a terrible person" | toxdet infer runs/bert_cc --threshold 0.5
36 | 
37 | # 7. Explain a prediction with SHAP
38 | toxdet explain --method shap --run runs/bert_cc --text "Example text to explain"
39 | ```
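40 | 
41 | ## Python API (sketch)
42 | 
43 | The CLI wraps a small Python API. Below is a minimal, hypothetical sketch of driving inference directly from Python; it assumes a trained run directory such as `runs/bert_cc` produced by the quickstart above.
44 | 
45 | ```python
46 | from toxicity_toolkit.infer.predictor import Predictor
47 | 
48 | # Load the checkpoint, tokenizer, and label list saved by `toxdet train`.
49 | predictor = Predictor.from_run("runs/bert_cc")
50 | 
51 | # Returns one dict per input with per-label scores and thresholded labels.
52 | results = predictor.predict(["You are a terrible person"], threshold=0.5)
53 | print(results[0]["labels"], results[0]["scores"])
54 | ```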
--------------------------------------------------------------------------------
/src/toxicity_toolkit/data/datamodule.py:
--------------------------------------------------------------------------------
1 | """
2 | DataModule — unified loader interface for different datasets.
3 | """
4 | 
5 | from typing import List, Dict, Optional
6 | from pathlib import Path
7 | from .adapters.hatexplain import load_hatexplain
8 | from .adapters.jigsaw import load_jigsaw
9 | 
10 | 
11 | class DataModule:
12 |     """
13 |     Wraps dataset preparation and access for supported datasets.
14 |     """
15 | 
16 |     def __init__(self, name: str, out_dir: Optional[Path] = None):
17 |         self.name = name.lower()
18 |         self.out_dir = Path(out_dir or "data") / self.name  # the dataset name is appended to out_dir
19 | 
20 |     # ------------------------------------------------------------------
21 |     @classmethod
22 |     def from_name(cls, name: str, out_dir: Optional[Path] = None):
23 |         return cls(name, out_dir)
24 | 
25 |     # ------------------------------------------------------------------
26 |     def prepare(self):
27 |         """
28 |         Download / preprocess dataset and cache locally.
29 |         """
30 |         self.out_dir.mkdir(parents=True, exist_ok=True)
31 | 
32 |         if self.name == "hatexplain":
33 |             data = load_hatexplain("train")
34 |         elif self.name == "jigsaw":
35 |             data = load_jigsaw("train")
36 |         else:
37 |             raise ValueError(f"Unknown dataset: {self.name}")
38 | 
39 |         # Save a preview so user can verify
40 |         import json
41 |         preview_path = self.out_dir / "preview.jsonl"
42 |         with open(preview_path, "w", encoding="utf-8") as f:
43 |             for row in data[:20]:
44 |                 f.write(json.dumps(row, ensure_ascii=False) + "\n")
45 |         print(f"✅ Saved sample preview to {preview_path}")
46 | 
47 |     # ------------------------------------------------------------------
48 |     def load(self, split: str = "train") -> List[Dict]:
49 |         """
50 |         Load a split into memory for training or evaluation.
51 |         """
52 |         if self.name == "hatexplain":
53 |             return load_hatexplain(split)
54 |         elif self.name == "jigsaw":
55 |             return load_jigsaw(split)
56 |         else:
57 |             raise ValueError(f"Unknown dataset: {self.name}")
58 | 
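59 | 
60 | # --- Hypothetical usage sketch (not part of the original module) ----------
61 | # Prepares the Civil Comments ("jigsaw") dataset and loads its train split.
62 | # Run as: python -m toxicity_toolkit.data.datamodule
63 | if __name__ == "__main__":
64 |     dm = DataModule.from_name("jigsaw", out_dir=Path("data"))
65 |     dm.prepare()
66 |     rows = dm.load("train")
67 |     print(f"Loaded {len(rows)} rows; first row: {rows[0]}")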
--------------------------------------------------------------------------------
/src/toxicity_toolkit.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.4
2 | Name: toxicity-toolkit
3 | Version: 0.1.0
4 | Summary: Open-source modular toxic content detection toolkit
5 | Author: Community Contributors
6 | Requires-Python: >=3.9
7 | Description-Content-Type: text/markdown
8 | Requires-Dist: transformers>=4.44.0
9 | Requires-Dist: datasets>=2.20.0
10 | Requires-Dist: torch>=2.1.0
11 | Requires-Dist: scikit-learn>=1.3.0
12 | Requires-Dist: pandas>=2.1.0
13 | Requires-Dist: numpy>=1.26.0
14 | Requires-Dist: typer>=0.12.3
15 | Requires-Dist: rich>=13.7.0
16 | Requires-Dist: pyyaml>=6.0.1
17 | Requires-Dist: tqdm>=4.66.0
18 | Requires-Dist: fastapi>=0.110.0
19 | Requires-Dist: uvicorn>=0.29.0
20 | Requires-Dist: shap>=0.44.0
21 | Requires-Dist: lime>=0.2.0.1
22 | Requires-Dist: streamlit>=1.37.0
23 | Requires-Dist: plotly>=5.22.0
24 | 
25 | # Toxic Content Detection Toolkit
26 | 
27 | Open-source, modular pipeline for detecting and classifying toxic text across multiple domains.
28 | 
29 | ## Key Features
30 | - **Multi-label classification**: hate, harassment, misinformation, spam (extensible).
31 | - **Dataset adapters**: HateXplain, Jigsaw, Reddit; pluggable interface.
32 | - **Pretrained checkpoints**: BERT/DistilBERT fine-tuned (community-hostable).
33 | - **Visualization dashboard**: toxicity heatmaps by topic/community.
34 | - **Explainability**: SHAP or LIME to interpret model decisions.
35 | - **Batteries included**: CLI, config, evaluation, unit tests, CI, model cards.
36 | 
37 | ## Quickstart
38 | 
39 | ```bash
40 | # 1) Install
41 | pip3 install -e .
42 | 
43 | # 2) Explore CLI
44 | toxdet --help
45 | 
46 | # 3) Download + preprocess a dataset (e.g., HateXplain)
47 | toxdet data prepare --dataset hatexplain --out data/hatexplain
48 | 
49 | # 4) Train a multi-label model
50 | toxdet train --dataset hatexplain --model bert-base-uncased --epochs 3 --output runs/bert_hx
51 | 
52 | # 5) Evaluate
53 | toxdet eval --run runs/bert_hx --split valid
54 | 
55 | # 6) Inference (JSONL in, JSON out)
56 | echo '{"text": "I disagree with you"}' | toxdet infer --run runs/bert_hx --threshold 0.5
57 | 
58 | # 7) Explain a prediction with SHAP
59 | toxdet explain shap --run runs/bert_hx --text "Example text to explain"
60 | 
--------------------------------------------------------------------------------
/src/toxicity_toolkit/infer/predictor.py:
--------------------------------------------------------------------------------
1 | import os, json, torch
2 | import numpy as np
3 | from transformers import AutoTokenizer
4 | from ..models.multilabel_bert import MultiLabelBertHead
5 | 
6 | class Predictor:
7 |     def __init__(self, tokenizer, model, labels):
8 |         self.tokenizer = tokenizer
9 |         self.model = model
10 |         self.labels = labels
11 |         self.device = "cuda" if torch.cuda.is_available() else "cpu"
12 |         self.model.eval()
13 |         self.model.to(self.device)
14 | 
15 |     @classmethod
16 |     def from_run(cls, run_dir):
17 |         run_dir = str(run_dir)
18 |         labels_path = os.path.join(run_dir, "labels.json")
19 |         model_path = os.path.join(run_dir, "pytorch_model.bin")
20 |         if not os.path.exists(model_path):
21 |             raise FileNotFoundError(f"Missing model checkpoint at {model_path}")
22 |         if not os.path.exists(labels_path):
23 |             raise FileNotFoundError(f"Missing labels.json at {labels_path}")
24 | 
25 |         with open(labels_path, "r") as f:
26 |             labels = json.load(f)
27 |         tokenizer = AutoTokenizer.from_pretrained(run_dir)
28 |         model = MultiLabelBertHead("bert-base-uncased", num_labels=len(labels))  # NOTE: the checkpoint does not record the base model name, so bert-base-uncased is assumed here
29 |         model.load_state_dict(torch.load(model_path, map_location="cpu"))
30 |         return cls(tokenizer, model, labels)
31 | 
32 |     def predict(self, texts, threshold=0.5):
33 |         """Predict labels for a list of input texts."""
34 |         results = []
35 |         for text in texts:
36 |             enc = self.tokenizer(
37 |                 text,
38 |                 truncation=True,
39 |                 padding="max_length",
40 |                 max_length=256,
41 |                 return_tensors="pt",
42 |             )
43 |             with torch.no_grad():
44 |                 logits = self.model(
45 |                     enc["input_ids"].to(self.device),
46 |                     enc["attention_mask"].to(self.device),
47 |                 )
48 |             probs = torch.sigmoid(logits).cpu().numpy().ravel()
49 |             labels = [self.labels[i] for i, p in enumerate(probs) if p >= threshold]
50 |             results.append(
51 |                 {
52 |                     "text": text,
53 |                     "scores": dict(zip(self.labels, probs.tolist())),
54 |                     "labels": labels,
55 |                 }
56 |             )
57 |         return results
58 | 
59 |     def evaluate(self, split="validation", threshold=0.5):
60 |         """Placeholder evaluation for CLI. Extend with your validation logic."""
61 |         print(f"Evaluating on split: {split} (threshold={threshold})")
62 |         return {"split": split, "threshold": threshold, "macro_f1": 0.0}
63 | 
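64 | 
65 | # --- Hypothetical helper (not part of the original module) ----------------
66 | # One way to flesh out `evaluate`: score a Predictor against rows shaped like
67 | # the DataModule output, i.e. {"text": ..., "labels": [...]}.
68 | def evaluate_on_rows(predictor, rows, threshold=0.5):
69 |     from ..train.metrics import multilabel_metrics
70 | 
71 |     outputs = predictor.predict([r["text"] for r in rows], threshold=threshold)
72 |     y_true = [[1 if lab in r.get("labels", []) else 0 for lab in predictor.labels] for r in rows]
73 |     y_pred = [[1 if lab in o["labels"] else 0 for lab in predictor.labels] for o in outputs]
74 |     return multilabel_metrics(y_true, y_pred)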
--------------------------------------------------------------------------------
/src/toxicity_toolkit/train/trainer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch, os, json
3 | from torch.utils.data import DataLoader, Dataset
4 | from transformers import AutoTokenizer, get_linear_schedule_with_warmup
5 | from ..models.multilabel_bert import MultiLabelBertHead
6 | from ..data.datamodule import DataModule
7 | from .metrics import multilabel_metrics
8 | 
9 | class SimpleDataset(Dataset):
10 |     def __init__(self, rows, tokenizer, labels):
11 |         self.rows, self.tokenizer, self.labels = rows, tokenizer, labels
12 |         self.lab2id = {l: i for i, l in enumerate(labels)}
13 | 
14 |     def __len__(self):
15 |         return len(self.rows)
16 | 
17 |     def __getitem__(self, i):
18 |         r = self.rows[i]
19 |         enc = self.tokenizer(
20 |             r["text"],
21 |             truncation=True,
22 |             padding="max_length",
23 |             max_length=256,
24 |             return_tensors="pt",
25 |         )
26 |         y = torch.zeros(len(self.labels))
27 |         for l in r.get("labels", []):
28 |             if l in self.lab2id:
29 |                 y[self.lab2id[l]] = 1.0
30 |         return enc["input_ids"].squeeze(0), enc["attention_mask"].squeeze(0), y
31 | 
32 | 
33 | class Trainer:
34 |     def __init__(self, model_name: str, labels: List[str], out_dir):
35 |         self.labels = labels
36 |         self.out_dir = os.path.abspath(out_dir)
37 |         os.makedirs(self.out_dir, exist_ok=True)
38 |         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
39 |         self.model = MultiLabelBertHead(model_name, num_labels=len(labels))
40 |         self.model.train()
41 | 
42 |     def fit(self, dm: DataModule, epochs=3, batch_size=16, lr=2e-5):
43 |         print(f"Loading dataset: {dm.name}")
44 |         rows = dm.load("train")
45 |         ds = SimpleDataset(rows, self.tokenizer, self.labels)
46 |         dl = DataLoader(ds, batch_size=batch_size, shuffle=True)
47 | 
48 |         opt = torch.optim.AdamW(self.model.parameters(), lr=lr)
49 |         total_steps = epochs * len(dl)
50 |         sched = get_linear_schedule_with_warmup(opt, 0, total_steps)
51 |         bce = torch.nn.BCEWithLogitsLoss()
52 |         device = "cuda" if torch.cuda.is_available() else "cpu"
53 |         self.model.to(device)
54 | 
55 |         for ep in range(epochs):
56 |             total_loss = 0.0
57 |             for input_ids, attn, y in dl:
58 |                 input_ids, attn, y = input_ids.to(device), attn.to(device), y.to(device)
59 |                 logits = self.model(input_ids, attn)
60 |                 loss = bce(logits, y)
61 |                 loss.backward()
62 |                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
63 |                 opt.step()
64 |                 sched.step()
65 |                 opt.zero_grad()
66 |                 total_loss += loss.item()
67 |             print(f"Epoch {ep+1}/{epochs} - Loss: {total_loss/len(dl):.4f}")
68 | 
69 |         torch.save(self.model.state_dict(), os.path.join(self.out_dir, "pytorch_model.bin"))
70 |         self.tokenizer.save_pretrained(self.out_dir)
71 |         with open(os.path.join(self.out_dir, "labels.json"), "w") as f:
72 |             json.dump(self.labels, f)
73 |         print(f"✅ Model saved to {self.out_dir}")
74 | 
--------------------------------------------------------------------------------
/data/jigsaw/jigsaw/preview.jsonl:
--------------------------------------------------------------------------------
1 | {"text": "This is so cool. It's like, 'would you want your mother to read this??' Really great idea, well done!", "labels": []}
2 | {"text": "Thank you!! This would make my life a lot less anxiety-inducing. Keep it up, and don't let anyone get in your way!", "labels": []}
3 | {"text": "This is such an urgent design problem; kudos to you for taking it on. Very impressive!", "labels": []}
4 | {"text": "Is this something I'll be able to install on my site? When will you be releasing it?", "labels": []}
5 | {"text": "haha you guys are a bunch of losers.", "labels": ["toxicity", "insult"]}
6 | {"text": "ur a sh*tty comment.", "labels": ["toxicity", "obscene"]}
7 | {"text": "hahahahahahahahhha suck it.", "labels": []}
8 | {"text": "FFFFUUUUUUUUUUUUUUU", "labels": []}
9 | {"text": "The ranchers seem motivated by mostly by greed; no one should have the right to allow their animals destroy public land.", "labels": []}
10 | {"text": "It was a great show. Not a combo I'd of expected to be good together but it was.", "labels": []}
11 | {"text": "Wow, that sounds great.", "labels": []}
12 | {"text": "This is a great story. Man. I wonder if the person who yelled \"shut the fuck up!\" at him ever heard it.", "labels": []}
13 | {"text": "This seems like a step in the right direction.", "labels": []}
14 | {"text": "It's ridiculous that these guys are being called \"protesters\". Being armed is a threat of violence, which makes them terrorists.", "labels": ["toxicity", "insult"]}
15 | {"text": "This story gets more ridiculous by the hour! And, I love that people are sending these guys dildos in the mail now. But… if they really think there's a happy ending in this for any of them, I think they're even more deluded than all of the jokes about them assume.", "labels": ["toxicity"]}
16 | {"text": "I agree; I don't want to grant them the legitimacy of protestors. They're greedy, small-minded people who somehow seem to share the mass delusion that this is not only a good idea for themselves as individuals, but is the right thing to do for ranchers at large. Basically: take something that currently belongs to everyone, and give it to a select group of people, so they can profit.", "labels": []}
17 | {"text": "Interesting. I'll be curious to see how this works out. I often refrain from commenting because I don't have the time or desire to engage with the couple of resident trolls who seem to jump on every active WW comment thread.", "labels": []}
18 | {"text": "Awesome! I love Civil Comments!", "labels": []}
19 | {"text": "I'm glad you're working on this, and I look forward to seeing how it plays out. The comments sections of online news stories have the potential to be great tools for community interaction about current events, a Neo Town Hall, of sorts. One of the reasons I rely on Reddit as a platform for news and local discussions is that there's a sense of community interaction that's often lacking in my \"real life,\" hectic discussions. \n\nBut hopefully we won't be tempted to silence those who take unpopular stances.", "labels": []}
20 | {"text": "Angry trolls, misogynists and Racists\", oh my. It doesn't take all of my 150 IQ to see the slant here. it's the \"Diversity diode\" at work yet again. \"We can say anything that we want because we are Diversity. You on the other hand must only say what we allow you to say. From now on, winning arguments against any member of diversity will be considered offensive language. facts, cogent, linear posts and Math are now verboten.", "labels": ["toxicity", "insult"]}
21 | 
--------------------------------------------------------------------------------
/src/toxicity_toolkit/cli.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import json, sys, pathlib
3 | import typer
4 | from rich import print as rprint
5 | 
6 | from .data.datamodule import DataModule
7 | from .train.trainer import Trainer
8 | from .infer.predictor import Predictor
9 | from .explain.shap_explain import shap_explain_text
10 | 
11 | # Root app
12 | app = typer.Typer(help="Toxic Content Detection Toolkit CLI")
13 | 
14 | # Data group for subcommands
15 | data_app = typer.Typer(help="Dataset utilities: download, prepare, inspect")
16 | app.add_typer(data_app, name="data")
17 | 
18 | 
19 | # ============================================================
20 | # DATA SUBCOMMANDS
21 | # ============================================================
22 | 
23 | @data_app.command("prepare")
24 | def data_prepare(
25 |     dataset: str = typer.Option("jigsaw", help="Dataset name: 'jigsaw' or 'hatexplain'"),
26 |     out: pathlib.Path = typer.Option(pathlib.Path("data") / "prepared", help="Output directory"),
27 | ):
28 |     """
29 |     Download and preprocess a dataset (HateXplain or Civil Comments, the Jigsaw Toxic Comments replacement).
30 |     """
31 |     dm = DataModule.from_name(dataset, out_dir=out)
32 |     dm.prepare()
33 |     rprint(f"[green]✅ Prepared {dataset} dataset at {out}[/green]")
34 | 
35 | 
36 | # ============================================================
37 | # TRAIN COMMAND
38 | # ============================================================
39 | 
40 | @app.command("train")
41 | def train(
42 |     dataset: str = "jigsaw",
43 |     model: str = "bert-base-uncased",
44 |     epochs: int = 1,
45 |     output: pathlib.Path = pathlib.Path("runs") / "bert_model",
46 |     batch_size: int = 16,
47 |     lr: float = 2e-5,
48 |     labels: Optional[str] = None,
49 | ):
50 |     """
51 |     Train a multi-label classifier on a supported dataset.
52 |     """
53 |     dm = DataModule.from_name(dataset)
54 | 
55 |     # Label defaults depending on dataset; the jigsaw defaults mirror the Civil Comments adapter columns
56 |     if labels:
57 |         label_list = labels.split(",")
58 |     elif dataset.lower() == "jigsaw":
59 |         label_list = ["toxicity", "severe_toxicity", "obscene", "threat", "insult", "identity_attack"]
60 |     else:
61 |         label_list = ["hate", "harassment", "misinformation", "spam"]
62 | 
63 |     trainer = Trainer(model_name=model, labels=label_list, out_dir=output)
64 |     trainer.fit(dm, epochs=epochs, batch_size=batch_size, lr=lr)
65 |     rprint(f"[green]✅ Training complete. Model saved to {output}[/green]")
66 | 
67 | 
68 | # ============================================================
69 | # EVAL COMMAND
70 | # ============================================================
71 | 
72 | @app.command("eval")
73 | def evaluate(
74 |     run: pathlib.Path = typer.Argument(..., help="Run/checkpoint directory"),
75 |     split: str = "validation",
76 |     threshold: float = 0.5,
77 | ):
78 |     """
79 |     Evaluate a trained checkpoint on a given split.
80 |     """
81 |     pred = Predictor.from_run(run)
82 |     metrics = pred.evaluate(split=split, threshold=threshold)
83 |     rprint(metrics)
84 | 
85 | 
86 | # ============================================================
87 | # INFER COMMAND
88 | # ============================================================
89 | 
90 | @app.command("infer")
91 | def infer(
92 |     run: pathlib.Path = typer.Argument(..., help="Run/checkpoint directory"),
93 |     threshold: float = 0.5,
94 | ):
95 |     """
96 |     Run inference from stdin (pipe or echo a text string).
97 |     Example:
98 |         echo "You are terrible!" | toxdet infer runs/bert_model
99 |     """
100 |     pred = Predictor.from_run(run)
101 |     text = sys.stdin.read().strip()
102 |     if not text:
103 |         rprint("[red]No input text provided![/red]")
104 |         raise typer.Exit(1)
105 |     result = pred.predict([text], threshold=threshold)[0]
106 |     rprint(json.dumps(result, ensure_ascii=False, indent=2))
107 | 
108 | 
109 | # ============================================================
110 | # EXPLAIN COMMAND
111 | # ============================================================
112 | 
113 | @app.command("explain")
114 | def explain(
115 |     method: str = "shap",
116 |     run: pathlib.Path = pathlib.Path("runs/bert_model"),
117 |     text: str = "Example toxic text",
118 | ):
119 |     """
120 |     Explain model predictions for a given text using SHAP or LIME.
121 |     """
122 |     if method.lower() == "shap":
123 |         shap_explain_text(run, text)
124 |     else:
125 |         rprint("[yellow]Only SHAP demo implemented in CLI for now.[/yellow]")
126 | 
127 | 
128 | # ============================================================
129 | # ENTRYPOINT
130 | # ============================================================
131 | 
132 | if __name__ == "__main__":
133 |     app()
134 | 
--------------------------------------------------------------------------------