"""
Phase 7: ML-Based Threat Classifier
Uses simple but effective ML techniques for:
- Attack pattern classification (SQLi, XSS, RCE, etc.)
- Anomaly scoring based on request features
- Risk-level prediction for proposed rules
Designed to work offline without heavy dependencies.
Uses a scikit-learn-style interface but can run on a pure-Python fallback.
"""
from __future__ import annotations
import hashlib
import json
import math
import re
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
# Try to import sklearn, fall back to pure Python
try:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import LabelEncoder
HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False
@dataclass
class ClassificationResult:
"""Result of classifying a threat indicator or pattern."""
label: str # "sqli", "xss", "rce", "path_traversal", "scanner", "benign", etc.
confidence: float # 0.0-1.0
probabilities: Dict[str, float] = field(default_factory=dict)
features_used: List[str] = field(default_factory=list)
explanation: str = ""
@dataclass
class AnomalyScore:
"""Anomaly detection result."""
score: float # 0.0-1.0 (higher = more anomalous)
baseline_deviation: float # standard deviations from mean
anomalous_features: List[str] = field(default_factory=list)
recommendation: str = ""
class FeatureExtractor:
"""Extract features from request/log data for ML classification."""
# Character distribution features
SPECIAL_CHARS = set("'\"<>(){}[];=&|`$\\")
# Known attack signatures for feature detection
SQLI_PATTERNS = [
r"(?i)union\s+select",
r"(?i)select\s+.*\s+from",
r"(?i)insert\s+into",
r"(?i)update\s+.*\s+set",
r"(?i)delete\s+from",
r"(?i)drop\s+table",
r"(?i);\s*--",
r"(?i)'\s*or\s+'?1'?\s*=\s*'?1",
r"(?i)'\s*and\s+'?1'?\s*=\s*'?1",
]
    XSS_PATTERNS = [
        r"(?i)<script[^>]*>",
        r"(?i)javascript\s*:",
        r"(?i)on(error|load|click|mouseover|focus)\s*=",
        r"(?i)<(img|svg|iframe|body|embed|object)\b",
        r"(?i)alert\s*\(",
        r"(?i)document\.(cookie|location|write)",
    ]
(" ", "xss"),
("javascript:alert(document.cookie)", "xss"),
("", "xss"),
("'\">", "xss"),
# RCE examples
("; cat /etc/passwd", "rce"),
("| ls -la", "rce"),
("`id`", "rce"),
("$(whoami)", "rce"),
("; rm -rf /", "rce"),
("system('cat /etc/passwd')", "rce"),
# Path traversal
("../../../etc/passwd", "path_traversal"),
("..\\..\\..\\windows\\system32\\config\\sam", "path_traversal"),
("/etc/passwd%00", "path_traversal"),
("....//....//etc/passwd", "path_traversal"),
# Scanner signatures
("Mozilla/5.0 (compatible; Nmap Scripting Engine)", "scanner"),
("sqlmap/1.0", "scanner"),
("Nikto/2.1.5", "scanner"),
("masscan/1.0", "scanner"),
# Benign examples
("/api/users/123", "benign"),
("Mozilla/5.0 (Windows NT 10.0; Win64; x64)", "benign"),
("/products?category=electronics&page=2", "benign"),
("GET /index.html HTTP/1.1", "benign"),
("/static/css/main.css", "benign"),
]
def train(self) -> None:
"""Train the classifier on built-in examples."""
# Extract features for all training data
X: List[Dict[str, float]] = []
y: List[str] = []
for text, label in self._training_data:
features = self.feature_extractor.extract(text)
X.append(features)
y.append(label)
# Calculate class priors
label_counts = Counter(y)
total = len(y)
for label, count in label_counts.items():
self._class_priors[label] = count / total
# Calculate feature means and variances per class
all_features = set()
for features in X:
all_features.update(features.keys())
for label in self.LABELS:
class_features = [X[i] for i in range(len(X)) if y[i] == label]
if not class_features:
continue
for feature in all_features:
values = [f.get(feature, 0.0) for f in class_features]
mean = sum(values) / len(values)
var = sum((v - mean) ** 2 for v in values) / len(values)
self._feature_means[label][feature] = mean
self._feature_vars[label][feature] = max(var, 1e-6) # avoid division by zero
self._trained = True
def classify(self, text: str) -> ClassificationResult:
"""Classify text into attack category."""
if not self._trained:
self.train()
features = self.feature_extractor.extract(text)
# Calculate log probabilities for each class
log_probs: Dict[str, float] = {}
for label in self.LABELS:
if label not in self._class_priors:
continue
log_prob = math.log(self._class_priors[label])
for feature, value in features.items():
if feature in self._feature_means[label]:
mean = self._feature_means[label][feature]
var = self._feature_vars[label][feature]
# Gaussian likelihood
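                    # log N(x; mean, var) = -0.5*log(2*pi*var) - (x - mean)^2 / (2*var)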
log_prob += -0.5 * math.log(2 * math.pi * var)
log_prob += -0.5 * ((value - mean) ** 2) / var
log_probs[label] = log_prob
# Convert to probabilities via softmax
max_log_prob = max(log_probs.values()) if log_probs else 0
exp_probs = {k: math.exp(v - max_log_prob) for k, v in log_probs.items()}
total = sum(exp_probs.values())
probs = {k: v / total for k, v in exp_probs.items()}
# Find best label
best_label = max(probs, key=probs.get) if probs else "benign"
confidence = probs.get(best_label, 0.0)
# Generate explanation
explanation = self._generate_explanation(text, features, best_label)
return ClassificationResult(
label=best_label,
confidence=confidence,
probabilities=probs,
features_used=list(features.keys()),
explanation=explanation
)
def _generate_explanation(self, text: str, features: Dict[str, float], label: str) -> str:
"""Generate human-readable explanation for classification."""
reasons = []
if features.get("sqli_score", 0) > 0.3:
reasons.append("SQL injection patterns detected")
if features.get("xss_score", 0) > 0.3:
reasons.append("XSS patterns detected")
if features.get("rce_score", 0) > 0.3:
reasons.append("Command injection patterns detected")
if features.get("path_traversal_score", 0) > 0.3:
reasons.append("Path traversal patterns detected")
if features.get("special_char_ratio", 0) > 0.2:
reasons.append("High special character ratio")
if features.get("entropy", 0) > 0.7:
reasons.append("High entropy (possible encoding/obfuscation)")
if not reasons:
reasons.append(f"General pattern matching suggests {label}")
return "; ".join(reasons)
class AnomalyDetector:
"""
Detect anomalous requests based on baseline behavior.
Uses statistical methods (z-score, IQR) without requiring ML libraries.
"""
def __init__(self):
self.feature_extractor = FeatureExtractor()
self._baseline_stats: Dict[str, Dict[str, float]] = {}
self._observations: List[Dict[str, float]] = []
def add_observation(self, text: str) -> None:
"""Add an observation to the baseline."""
features = self.feature_extractor.extract(text)
self._observations.append(features)
# Recalculate baseline after enough observations
if len(self._observations) >= 10:
self._update_baseline()
def _update_baseline(self) -> None:
"""Update baseline statistics."""
if not self._observations:
return
all_features = set()
for obs in self._observations:
all_features.update(obs.keys())
for feature in all_features:
values = [obs.get(feature, 0.0) for obs in self._observations]
mean = sum(values) / len(values)
var = sum((v - mean) ** 2 for v in values) / len(values)
std = math.sqrt(var) if var > 0 else 0.001
self._baseline_stats[feature] = {
"mean": mean,
"std": std,
"min": min(values),
"max": max(values),
}
def score(self, text: str) -> AnomalyScore:
"""Score how anomalous a request is."""
features = self.feature_extractor.extract(text)
if not self._baseline_stats:
# No baseline yet, use heuristics
return self._heuristic_score(features)
z_scores: Dict[str, float] = {}
anomalous_features: List[str] = []
for feature, value in features.items():
if feature in self._baseline_stats:
stats = self._baseline_stats[feature]
z = (value - stats["mean"]) / stats["std"]
z_scores[feature] = abs(z)
if abs(z) > 2: # More than 2 std deviations
anomalous_features.append(f"{feature} (z={z:.2f})")
# Overall anomaly score (average of z-scores, normalized)
if z_scores:
avg_z = sum(z_scores.values()) / len(z_scores)
max_z = max(z_scores.values())
score = min(max_z / 5, 1.0) # Normalize to 0-1
baseline_deviation = avg_z
else:
score = 0.5
baseline_deviation = 0.0
# Generate recommendation
if score > 0.8:
recommendation = "BLOCK: Highly anomalous, likely attack"
elif score > 0.5:
recommendation = "CHALLENGE: Moderately anomalous, requires verification"
elif score > 0.3:
recommendation = "LOG: Slightly unusual, monitor closely"
else:
recommendation = "ALLOW: Within normal parameters"
return AnomalyScore(
score=score,
baseline_deviation=baseline_deviation,
anomalous_features=anomalous_features,
recommendation=recommendation
)
def _heuristic_score(self, features: Dict[str, float]) -> AnomalyScore:
"""Score based on heuristics when no baseline exists."""
score = 0.0
anomalous_features: List[str] = []
# Check for attack indicators
for attack_type in ["sqli_score", "xss_score", "rce_score", "path_traversal_score"]:
if features.get(attack_type, 0) > 0.3:
score += 0.25
anomalous_features.append(attack_type)
# Check for suspicious characteristics
if features.get("special_char_ratio", 0) > 0.15:
score += 0.15
anomalous_features.append("high_special_chars")
if features.get("entropy", 0) > 0.8:
score += 0.1
anomalous_features.append("high_entropy")
score = min(score, 1.0)
if score > 0.7:
recommendation = "BLOCK: Multiple attack indicators"
elif score > 0.4:
recommendation = "CHALLENGE: Suspicious characteristics"
else:
recommendation = "ALLOW: No obvious threats"
return AnomalyScore(
score=score,
baseline_deviation=0.0,
anomalous_features=anomalous_features,
recommendation=recommendation
)
class ThreatClassifier:
"""
High-level threat classifier combining multiple techniques.
Usage:
classifier = ThreatClassifier()
result = classifier.classify("' OR '1'='1")
print(f"Label: {result.label}, Confidence: {result.confidence}")
"""
def __init__(self, model_path: Optional[Path] = None):
self.naive_bayes = NaiveBayesClassifier()
self.anomaly_detector = AnomalyDetector()
self.model_path = model_path
# Train on startup
self.naive_bayes.train()
def classify(self, text: str) -> ClassificationResult:
"""Classify a request/pattern."""
return self.naive_bayes.classify(text)
def score_anomaly(self, text: str) -> AnomalyScore:
"""Score how anomalous a request is."""
return self.anomaly_detector.score(text)
def analyze(self, text: str) -> Dict[str, Any]:
"""Full analysis combining classification and anomaly detection."""
classification = self.classify(text)
anomaly = self.score_anomaly(text)
return {
"classification": {
"label": classification.label,
"confidence": classification.confidence,
"probabilities": classification.probabilities,
"explanation": classification.explanation,
},
"anomaly": {
"score": anomaly.score,
"baseline_deviation": anomaly.baseline_deviation,
"anomalous_features": anomaly.anomalous_features,
"recommendation": anomaly.recommendation,
},
"risk_level": self._compute_risk_level(classification, anomaly),
}
def _compute_risk_level(
self,
classification: ClassificationResult,
anomaly: AnomalyScore
) -> str:
"""Compute overall risk level."""
# High-risk attack types
high_risk_labels = {"sqli", "xss", "rce"}
if classification.label in high_risk_labels and classification.confidence > 0.7:
return "critical"
if classification.label in high_risk_labels and classification.confidence > 0.4:
return "high"
if anomaly.score > 0.7:
return "high"
if classification.label == "scanner":
return "medium"
if anomaly.score > 0.4:
return "medium"
return "low"
# CLI for testing
if __name__ == "__main__":
import sys
classifier = ThreatClassifier()
test_inputs = [
"' OR '1'='1",
"",
"; cat /etc/passwd",
"../../../etc/passwd",
"Mozilla/5.0 (Windows NT 10.0)",
"/api/users/123",
]
if len(sys.argv) > 1:
test_inputs = sys.argv[1:]
print("\n🤖 ML Threat Classifier Test")
print("=" * 60)
for text in test_inputs:
result = classifier.analyze(text)
print(f"\nInput: {text[:50]}...")
print(f" Label: {result['classification']['label']}")
print(f" Confidence: {result['classification']['confidence']:.2%}")
print(f" Risk Level: {result['risk_level'].upper()}")
print(f" Anomaly Score: {result['anomaly']['score']:.2%}")
print(f" Recommendation: {result['anomaly']['recommendation']}")