vm-cloudflare/mcp/waf_intelligence/classifier.py
"""
Phase 7: ML-Based Threat Classifier
Uses simple but effective ML techniques for:
- Attack pattern classification (SQLi, XSS, RCE, etc.)
- Anomaly scoring based on request features
- Risk-level prediction for proposed rules
Designed to work offline without heavy dependencies.
Uses scikit-learn-style interface but can run with pure Python fallback.
"""
from __future__ import annotations
import hashlib
import json
import math
import re
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from layer0 import layer0_entry
from layer0.shadow_classifier import ShadowEvalResult
# Try to import sklearn, fall back to pure Python
try:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.naive_bayes import MultinomialNB
    from sklearn.preprocessing import LabelEncoder

    HAS_SKLEARN = True
except ImportError:
    HAS_SKLEARN = False
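
# The sklearn imports above only set HAS_SKLEARN in this module; below is a minimal
# sketch of how the flag could gate an optional TF-IDF path. The helper name
# `_optional_tfidf_vectorizer` is illustrative and not part of this module's API.
def _optional_tfidf_vectorizer():
    """Return a character n-gram TfidfVectorizer if sklearn is available, else None."""
    if not HAS_SKLEARN:
        return None
    # Character n-grams tend to work better than word tokens for short attack payloads.
    return TfidfVectorizer(analyzer="char", ngram_range=(2, 4))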
@dataclass
class ClassificationResult:
    """Result of classifying a threat indicator or pattern."""

    label: str  # "sqli", "xss", "rce", "path_traversal", "scanner", "benign", etc.
    confidence: float  # 0.0-1.0
    probabilities: Dict[str, float] = field(default_factory=dict)
    features_used: List[str] = field(default_factory=list)
    explanation: str = ""


@dataclass
class AnomalyScore:
    """Anomaly detection result."""

    score: float  # 0.0-1.0 (higher = more anomalous)
    baseline_deviation: float  # standard deviations from mean
    anomalous_features: List[str] = field(default_factory=list)
    recommendation: str = ""
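
# A minimal sketch of how downstream code might consume these two result objects.
# The 0.8 and 0.5 thresholds are illustrative assumptions, not values defined
# elsewhere in this module.
def _should_escalate(result: ClassificationResult, anomaly: AnomalyScore) -> bool:
    """Example gate: escalate only when the classifier and the anomaly score agree."""
    return result.label != "benign" and result.confidence > 0.8 and anomaly.score > 0.5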
class FeatureExtractor:
    """Extract features from request/log data for ML classification."""

    # Character distribution features
    SPECIAL_CHARS = set("'\"<>(){}[];=&|`$\\")

    # Known attack signatures for feature detection
    SQLI_PATTERNS = [
        r"(?i)union\s+select",
        r"(?i)select\s+.*\s+from",
        r"(?i)insert\s+into",
        r"(?i)update\s+.*\s+set",
        r"(?i)delete\s+from",
        r"(?i)drop\s+table",
        r"(?i);\s*--",
        r"(?i)'\s*or\s+'?1'?\s*=\s*'?1",
        r"(?i)'\s*and\s+'?1'?\s*=\s*'?1",
    ]
    XSS_PATTERNS = [
        r"(?i)<script",
        r"(?i)javascript:",
        r"(?i)on\w+\s*=",
        r"(?i)alert\s*\(",
        r"(?i)document\.",
        r"(?i)window\.",
        r"(?i)eval\s*\(",
    ]
    RCE_PATTERNS = [
        r"(?i);\s*(?:cat|ls|id|whoami|pwd)",
        r"(?i)\|\s*(?:cat|ls|id|whoami)",
        r"(?i)`[^`]+`",
        r"(?i)\$\([^)]+\)",
        r"(?i)system\s*\(",
        r"(?i)exec\s*\(",
        r"(?i)passthru\s*\(",
    ]
    PATH_TRAVERSAL_PATTERNS = [
        r"\.\./",
        r"\.\.\\",
        r"(?i)etc/passwd",
        r"(?i)windows/system32",
    ]

    def extract(self, text: str) -> Dict[str, float]:
        """Extract numerical features from text."""
        features: Dict[str, float] = {}
        if not text:
            return features

        text_lower = text.lower()
        text_len = len(text)

        # Length features
        features["length"] = min(text_len / 1000, 1.0)  # normalized
        features["length_log"] = math.log1p(text_len) / 10

        # Character distribution
        special_count = sum(1 for c in text if c in self.SPECIAL_CHARS)
        features["special_char_ratio"] = special_count / max(text_len, 1)
        features["uppercase_ratio"] = sum(1 for c in text if c.isupper()) / max(text_len, 1)
        features["digit_ratio"] = sum(1 for c in text if c.isdigit()) / max(text_len, 1)

        # Entropy (randomness indicator)
        features["entropy"] = self._calculate_entropy(text)

        # Pattern-based features
        features["sqli_score"] = self._pattern_score(text, self.SQLI_PATTERNS)
        features["xss_score"] = self._pattern_score(text, self.XSS_PATTERNS)
        features["rce_score"] = self._pattern_score(text, self.RCE_PATTERNS)
        features["path_traversal_score"] = self._pattern_score(text, self.PATH_TRAVERSAL_PATTERNS)

        # Structural features
        features["quote_count"] = (text.count("'") + text.count('"')) / max(text_len, 1)
        features["paren_count"] = (text.count("(") + text.count(")")) / max(text_len, 1)
        features["bracket_count"] = (
            text.count("[") + text.count("]") + text.count("{") + text.count("}")
        ) / max(text_len, 1)

        # Keyword presence
        features["has_select"] = 1.0 if "select" in text_lower else 0.0
        features["has_script"] = 1.0 if "<script" in text_lower else 0.0
        features["has_etc_passwd"] = 1.0 if "etc/passwd" in text_lower else 0.0

        return features

    def _calculate_entropy(self, text: str) -> float:
        """Calculate Shannon entropy of text."""
        if not text:
            return 0.0
        freq = Counter(text)
        length = len(text)
        entropy = 0.0
        for count in freq.values():
            prob = count / length
            if prob > 0:
                entropy -= prob * math.log2(prob)
        # Normalize to 0-1 range (max entropy for ASCII is ~7 bits)
        return min(entropy / 7, 1.0)

    def _pattern_score(self, text: str, patterns: List[str]) -> float:
        """Calculate pattern match score."""
        matches = sum(1 for p in patterns if re.search(p, text))
        return min(matches / max(len(patterns), 1), 1.0)
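
# A minimal usage sketch for FeatureExtractor, run against a single raw payload string.
# The sample payload is illustrative only.
def _demo_feature_extraction() -> None:
    extractor = FeatureExtractor()
    feats = extractor.extract("' OR '1'='1")
    # Pattern scores are the fraction of signatures that match: here one of the nine
    # SQLi patterns matches, so sqli_score is roughly 0.11.
    print(f"sqli_score={feats['sqli_score']:.2f}, entropy={feats['entropy']:.2f}")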
class NaiveBayesClassifier:
    """
    Simple Gaussian Naive Bayes classifier for attack type classification.
    Implemented in pure Python, so it works with or without sklearn installed.
    """

    LABELS = ["sqli", "xss", "rce", "path_traversal", "scanner", "benign"]

    def __init__(self):
        self.feature_extractor = FeatureExtractor()
        self._trained = False
        # Training data (curated examples)
        self._training_data = self._get_training_data()
        # Feature statistics per class (for pure Python implementation)
        self._class_priors: Dict[str, float] = {}
        self._feature_means: Dict[str, Dict[str, float]] = defaultdict(dict)
        self._feature_vars: Dict[str, Dict[str, float]] = defaultdict(dict)

    def _get_training_data(self) -> List[Tuple[str, str]]:
        """Return curated training examples."""
        return [
            # SQLi examples
            ("' OR '1'='1", "sqli"),
            ("1; DROP TABLE users--", "sqli"),
            ("UNION SELECT * FROM passwords", "sqli"),
            ("admin'--", "sqli"),
            ("1' AND 1=1--", "sqli"),
            ("'; INSERT INTO users VALUES('hack','hack')--", "sqli"),
            # XSS examples
            ("<script>alert('xss')</script>", "xss"),
            ("<img src=x onerror=alert(1)>", "xss"),
            ("javascript:alert(document.cookie)", "xss"),
            ("<svg onload=alert(1)>", "xss"),
            ("'\"><script>alert('XSS')</script>", "xss"),
            # RCE examples
            ("; cat /etc/passwd", "rce"),
            ("| ls -la", "rce"),
            ("`id`", "rce"),
            ("$(whoami)", "rce"),
            ("; rm -rf /", "rce"),
            ("system('cat /etc/passwd')", "rce"),
            # Path traversal
            ("../../../etc/passwd", "path_traversal"),
            ("..\\..\\..\\windows\\system32\\config\\sam", "path_traversal"),
            ("/etc/passwd%00", "path_traversal"),
            ("....//....//etc/passwd", "path_traversal"),
            # Scanner signatures
            ("Mozilla/5.0 (compatible; Nmap Scripting Engine)", "scanner"),
            ("sqlmap/1.0", "scanner"),
            ("Nikto/2.1.5", "scanner"),
            ("masscan/1.0", "scanner"),
            # Benign examples
            ("/api/users/123", "benign"),
            ("Mozilla/5.0 (Windows NT 10.0; Win64; x64)", "benign"),
            ("/products?category=electronics&page=2", "benign"),
            ("GET /index.html HTTP/1.1", "benign"),
            ("/static/css/main.css", "benign"),
        ]

    def train(self) -> None:
        """Train the classifier on built-in examples."""
        # Extract features for all training data
        X: List[Dict[str, float]] = []
        y: List[str] = []
        for text, label in self._training_data:
            features = self.feature_extractor.extract(text)
            X.append(features)
            y.append(label)

        # Calculate class priors
        label_counts = Counter(y)
        total = len(y)
        for label, count in label_counts.items():
            self._class_priors[label] = count / total

        # Calculate feature means and variances per class
        all_features = set()
        for features in X:
            all_features.update(features.keys())
        for label in self.LABELS:
            class_features = [X[i] for i in range(len(X)) if y[i] == label]
            if not class_features:
                continue
            for feature in all_features:
                values = [f.get(feature, 0.0) for f in class_features]
                mean = sum(values) / len(values)
                var = sum((v - mean) ** 2 for v in values) / len(values)
                self._feature_means[label][feature] = mean
                self._feature_vars[label][feature] = max(var, 1e-6)  # avoid division by zero

        self._trained = True

    def classify(self, text: str) -> ClassificationResult:
        """Classify text into attack category."""
        if not self._trained:
            self.train()

        features = self.feature_extractor.extract(text)

        # Calculate log probabilities for each class
        log_probs: Dict[str, float] = {}
        for label in self.LABELS:
            if label not in self._class_priors:
                continue
            log_prob = math.log(self._class_priors[label])
            for feature, value in features.items():
                if feature in self._feature_means[label]:
                    mean = self._feature_means[label][feature]
                    var = self._feature_vars[label][feature]
                    # Gaussian likelihood
                    log_prob += -0.5 * math.log(2 * math.pi * var)
                    log_prob += -0.5 * ((value - mean) ** 2) / var
            log_probs[label] = log_prob

        # Convert to probabilities via softmax
        max_log_prob = max(log_probs.values()) if log_probs else 0
        exp_probs = {k: math.exp(v - max_log_prob) for k, v in log_probs.items()}
        total = sum(exp_probs.values())
        probs = {k: v / total for k, v in exp_probs.items()}

        # Find best label
        best_label = max(probs, key=probs.get) if probs else "benign"
        confidence = probs.get(best_label, 0.0)

        # Generate explanation
        explanation = self._generate_explanation(text, features, best_label)

        return ClassificationResult(
            label=best_label,
            confidence=confidence,
            probabilities=probs,
            features_used=list(features.keys()),
            explanation=explanation,
        )

    def _generate_explanation(self, text: str, features: Dict[str, float], label: str) -> str:
        """Generate human-readable explanation for classification."""
        reasons = []
        if features.get("sqli_score", 0) > 0.3:
            reasons.append("SQL injection patterns detected")
        if features.get("xss_score", 0) > 0.3:
            reasons.append("XSS patterns detected")
        if features.get("rce_score", 0) > 0.3:
            reasons.append("Command injection patterns detected")
        if features.get("path_traversal_score", 0) > 0.3:
            reasons.append("Path traversal patterns detected")
        if features.get("special_char_ratio", 0) > 0.2:
            reasons.append("High special character ratio")
        if features.get("entropy", 0) > 0.7:
            reasons.append("High entropy (possible encoding/obfuscation)")
        if not reasons:
            reasons.append(f"General pattern matching suggests {label}")
        return "; ".join(reasons)
class AnomalyDetector:
    """
    Detect anomalous requests based on baseline behavior.
    Uses simple statistics (z-scores against a learned baseline) without requiring ML libraries.
    """

    def __init__(self):
        self.feature_extractor = FeatureExtractor()
        self._baseline_stats: Dict[str, Dict[str, float]] = {}
        self._observations: List[Dict[str, float]] = []

    def add_observation(self, text: str) -> None:
        """Add an observation to the baseline."""
        features = self.feature_extractor.extract(text)
        self._observations.append(features)
        # Recalculate baseline after enough observations
        if len(self._observations) >= 10:
            self._update_baseline()

    def _update_baseline(self) -> None:
        """Update baseline statistics."""
        if not self._observations:
            return
        all_features = set()
        for obs in self._observations:
            all_features.update(obs.keys())
        for feature in all_features:
            values = [obs.get(feature, 0.0) for obs in self._observations]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            std = math.sqrt(var) if var > 0 else 0.001
            self._baseline_stats[feature] = {
                "mean": mean,
                "std": std,
                "min": min(values),
                "max": max(values),
            }

    def score(self, text: str) -> AnomalyScore:
        """Score how anomalous a request is."""
        features = self.feature_extractor.extract(text)

        if not self._baseline_stats:
            # No baseline yet, use heuristics
            return self._heuristic_score(features)

        z_scores: Dict[str, float] = {}
        anomalous_features: List[str] = []
        for feature, value in features.items():
            if feature in self._baseline_stats:
                stats = self._baseline_stats[feature]
                z = (value - stats["mean"]) / stats["std"]
                z_scores[feature] = abs(z)
                if abs(z) > 2:  # More than 2 std deviations
                    anomalous_features.append(f"{feature} (z={z:.2f})")

        # Overall anomaly score: normalize the worst (max) z-score; the average is
        # reported separately as the baseline deviation.
        if z_scores:
            avg_z = sum(z_scores.values()) / len(z_scores)
            max_z = max(z_scores.values())
            score = min(max_z / 5, 1.0)  # Normalize to 0-1
            baseline_deviation = avg_z
        else:
            score = 0.5
            baseline_deviation = 0.0

        # Generate recommendation
        if score > 0.8:
            recommendation = "BLOCK: Highly anomalous, likely attack"
        elif score > 0.5:
            recommendation = "CHALLENGE: Moderately anomalous, requires verification"
        elif score > 0.3:
            recommendation = "LOG: Slightly unusual, monitor closely"
        else:
            recommendation = "ALLOW: Within normal parameters"

        return AnomalyScore(
            score=score,
            baseline_deviation=baseline_deviation,
            anomalous_features=anomalous_features,
            recommendation=recommendation,
        )

    def _heuristic_score(self, features: Dict[str, float]) -> AnomalyScore:
        """Score based on heuristics when no baseline exists."""
        score = 0.0
        anomalous_features: List[str] = []

        # Check for attack indicators
        for attack_type in ["sqli_score", "xss_score", "rce_score", "path_traversal_score"]:
            if features.get(attack_type, 0) > 0.3:
                score += 0.25
                anomalous_features.append(attack_type)

        # Check for suspicious characteristics
        if features.get("special_char_ratio", 0) > 0.15:
            score += 0.15
            anomalous_features.append("high_special_chars")
        if features.get("entropy", 0) > 0.8:
            score += 0.1
            anomalous_features.append("high_entropy")

        score = min(score, 1.0)

        if score > 0.7:
            recommendation = "BLOCK: Multiple attack indicators"
        elif score > 0.4:
            recommendation = "CHALLENGE: Suspicious characteristics"
        else:
            recommendation = "ALLOW: No obvious threats"

        return AnomalyScore(
            score=score,
            baseline_deviation=0.0,
            anomalous_features=anomalous_features,
            recommendation=recommendation,
        )
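
# A minimal sketch of the intended AnomalyDetector workflow: feed it known-good traffic
# to build a baseline, then z-score new requests against that baseline. The sample paths
# below are illustrative assumptions, not real traffic.
def _demo_anomaly_baseline() -> None:
    detector = AnomalyDetector()
    for path in ["/api/users/1", "/api/users/2", "/static/app.js"] * 4:  # >= 10 observations
        detector.add_observation(path)
    result = detector.score("/api/users?id=1' OR '1'='1")
    print(f"anomaly={result.score:.2f} -> {result.recommendation}")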
class ThreatClassifier:
    """
    High-level threat classifier combining multiple techniques.

    Usage:
        classifier = ThreatClassifier()
        result = classifier.classify("' OR '1'='1")
        print(f"Label: {result.label}, Confidence: {result.confidence}")
    """

    def __init__(self, model_path: Optional[Path] = None):
        self.naive_bayes = NaiveBayesClassifier()
        self.anomaly_detector = AnomalyDetector()
        self.model_path = model_path
        # Train on startup
        self.naive_bayes.train()

    def classify(self, text: str) -> ClassificationResult:
        """Classify a request/pattern."""
        return self.naive_bayes.classify(text)

    def score_anomaly(self, text: str) -> AnomalyScore:
        """Score how anomalous a request is."""
        return self.anomaly_detector.score(text)

    def analyze(self, text: str) -> Dict[str, Any]:
        """Full analysis combining classification and anomaly detection."""
        classification = self.classify(text)
        anomaly = self.score_anomaly(text)
        return {
            "classification": {
                "label": classification.label,
                "confidence": classification.confidence,
                "probabilities": classification.probabilities,
                "explanation": classification.explanation,
            },
            "anomaly": {
                "score": anomaly.score,
                "baseline_deviation": anomaly.baseline_deviation,
                "anomalous_features": anomaly.anomalous_features,
                "recommendation": anomaly.recommendation,
            },
            "risk_level": self._compute_risk_level(classification, anomaly),
        }

    def _compute_risk_level(
        self,
        classification: ClassificationResult,
        anomaly: AnomalyScore,
    ) -> str:
        """Compute overall risk level."""
        # High-risk attack types
        high_risk_labels = {"sqli", "xss", "rce"}
        if classification.label in high_risk_labels and classification.confidence > 0.7:
            return "critical"
        if classification.label in high_risk_labels and classification.confidence > 0.4:
            return "high"
        if anomaly.score > 0.7:
            return "high"
        if classification.label == "scanner":
            return "medium"
        if anomaly.score > 0.4:
            return "medium"
        return "low"
def _layer0_cli_msg(routing_action: str, shadow: ShadowEvalResult) -> str:
    """Translate a Layer 0 routing action into a short CLI message."""
    if routing_action == "FAIL_CLOSED":
        return "Layer 0: cannot comply with this request."
    if routing_action == "HANDOFF_TO_GUARDRAILS":
        reason = shadow.reason or "governance_violation"
        return f"Layer 0: governance violation detected ({reason})."
    if routing_action == "PROMPT_FOR_CLARIFICATION":
        return "Layer 0: request is ambiguous. Please add specifics before rerunning."
    return "Layer 0: unrecognized routing action; refusing request."


# CLI for testing (the helper above is defined first so it exists when the loop runs)
if __name__ == "__main__":
    import sys

    classifier = ThreatClassifier()
    test_inputs = [
        "' OR '1'='1",
        "<script>alert('xss')</script>",
        "; cat /etc/passwd",
        "../../../etc/passwd",
        "Mozilla/5.0 (Windows NT 10.0)",
        "/api/users/123",
    ]
    if len(sys.argv) > 1:
        test_inputs = sys.argv[1:]

    print("\n🤖 ML Threat Classifier Test")
    print("=" * 60)
    for text in test_inputs:
        routing_action, shadow = layer0_entry(text)
        if routing_action != "HANDOFF_TO_LAYER1":
            print(_layer0_cli_msg(routing_action, shadow), file=sys.stderr)
            continue
        result = classifier.analyze(text)
        print(f"\nInput: {text[:50]}...")
        print(f"  Label: {result['classification']['label']}")
        print(f"  Confidence: {result['classification']['confidence']:.2%}")
        print(f"  Risk Level: {result['risk_level'].upper()}")
        print(f"  Anomaly Score: {result['anomaly']['score']:.2%}")
        print(f"  Recommendation: {result['anomaly']['recommendation']}")