"""
Phase 7: ML-Based Threat Classifier

Uses simple but effective ML techniques for:
- Attack pattern classification (SQLi, XSS, RCE, etc.)
- Anomaly scoring based on request features
- Risk-level prediction for proposed rules

Designed to work offline without heavy dependencies. scikit-learn is
detected at import time (``HAS_SKLEARN``), but the classifiers below are
implemented in pure Python and do not currently depend on it.
"""

from __future__ import annotations

import hashlib
import json
import math
import re
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

from layer0 import layer0_entry
from layer0.shadow_classifier import ShadowEvalResult

# Try to import sklearn, fall back to pure Python
try:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.naive_bayes import MultinomialNB
    from sklearn.preprocessing import LabelEncoder

    HAS_SKLEARN = True
except ImportError:
    HAS_SKLEARN = False


@dataclass
class ClassificationResult:
    """Result of classifying a threat indicator or pattern."""

    label: str  # "sqli", "xss", "rce", "path_traversal", "scanner", "benign", etc.
    confidence: float  # 0.0-1.0 (probability of the winning label)
    probabilities: Dict[str, float] = field(default_factory=dict)  # label -> probability
    features_used: List[str] = field(default_factory=list)  # feature names extracted from the input
    explanation: str = ""  # human-readable reasons for the label


@dataclass
class AnomalyScore:
    """Anomaly detection result."""

    score: float  # 0.0-1.0 (higher = more anomalous)
    baseline_deviation: float  # mean absolute z-score vs. the baseline (0.0 without a baseline)
    anomalous_features: List[str] = field(default_factory=list)
    recommendation: str = ""  # "BLOCK: ...", "CHALLENGE: ...", "LOG: ..." or "ALLOW: ..."


class FeatureExtractor:
    """Extract numerical features from request/log text for ML classification."""

    # Characters frequently used to break out of SQL / HTML / shell contexts.
    SPECIAL_CHARS = set("'\"<>(){}[];=&|`$\\")

    # Known attack signatures for feature detection
    SQLI_PATTERNS = [
        r"(?i)union\s+select",
        r"(?i)select\s+.*\s+from",
        r"(?i)insert\s+into",
        r"(?i)update\s+.*\s+set",
        r"(?i)delete\s+from",
        r"(?i)drop\s+table",
        r"(?i);\s*--",
        r"(?i)'\s*or\s+'?1'?\s*=\s*'?1",
        r"(?i)'\s*and\s+'?1'?\s*=\s*'?1",
    ]

    # NOTE(review): the XSS, RCE and path-traversal lists below are a
    # reconstruction of signature lists whose text was lost in this copy of
    # the file — confirm against the intended signature set.
    XSS_PATTERNS = [
        r"(?i)<script",
        r"(?i)javascript:",
        r"(?i)\bon\w+\s*=",
        r"(?i)<iframe",
        r"(?i)<img[^>]+onerror",
        r"(?i)<svg[^>]*onload",
        r"(?i)document\.cookie",
        r"(?i)alert\s*\(",
    ]

    RCE_PATTERNS = [
        r"[;|&]\s*(cat|ls|id|whoami|uname|rm|wget|curl|nc|bash|sh)\b",
        r"`[^`]+`",
        r"\$\([^)]*\)",
        r"(?i)\b(system|exec|eval|passthru|shell_exec|popen)\s*\(",
        r"/etc/(passwd|shadow)",
    ]

    PATH_TRAVERSAL_PATTERNS = [
        r"\.\./",
        r"\.\.\\",
        r"(?i)%2e%2e(%2f|%5c)",
        r"%00",
        r"(?i)/etc/(passwd|shadow|hosts)",
        r"(?i)\\windows\\",
    ]

    def extract(self, text: str) -> Dict[str, float]:
        """Extract numerical features from *text*.

        Returns an empty dict for empty text. For non-empty text every key
        below is always present, so all non-empty inputs share one feature
        set. Values are mostly normalized to roughly the 0.0-1.0 range.
        """
        features: Dict[str, float] = {}
        if not text:
            return features

        text_lower = text.lower()
        text_len = len(text)

        # Length features
        features["length"] = min(text_len / 1000, 1.0)  # normalized
        features["length_log"] = math.log1p(text_len) / 10

        # Character distribution
        special_count = sum(1 for c in text if c in self.SPECIAL_CHARS)
        features["special_char_ratio"] = special_count / max(text_len, 1)
        features["uppercase_ratio"] = sum(1 for c in text if c.isupper()) / max(text_len, 1)
        features["digit_ratio"] = sum(1 for c in text if c.isdigit()) / max(text_len, 1)

        # Entropy (randomness indicator)
        features["entropy"] = self._calculate_entropy(text)

        # Pattern-based features: fraction of each signature list that matches.
        features["sqli_score"] = self._pattern_score(text, self.SQLI_PATTERNS)
        features["xss_score"] = self._pattern_score(text, self.XSS_PATTERNS)
        features["rce_score"] = self._pattern_score(text, self.RCE_PATTERNS)
        features["path_traversal_score"] = self._pattern_score(text, self.PATH_TRAVERSAL_PATTERNS)

        # Structural features
        features["quote_count"] = (text.count("'") + text.count('"')) / max(text_len, 1)
        features["paren_count"] = (text.count("(") + text.count(")")) / max(text_len, 1)
        features["bracket_count"] = (
            text.count("[") + text.count("]") + text.count("{") + text.count("}")
        ) / max(text_len, 1)

        # Keyword presence
        # NOTE(review): the original keyword list was partly lost; only the
        # two keywords still visible in this copy are kept.
        features["has_select"] = 1.0 if "select" in text_lower else 0.0
        features["has_script"] = 1.0 if "<script" in text_lower else 0.0

        return features

    def _calculate_entropy(self, text: str) -> float:
        """Calculate Shannon entropy of text, normalized to 0.0-1.0."""
        if not text:
            return 0.0

        freq = Counter(text)
        length = len(text)
        entropy = 0.0
        for count in freq.values():
            prob = count / length
            if prob > 0:
                entropy -= prob * math.log2(prob)

        # Normalize by 7 bits (log2 of the 128-symbol ASCII alphabet); clamp
        # because non-ASCII text can exceed it.
        return min(entropy / 7, 1.0)

    def _pattern_score(self, text: str, patterns: List[str]) -> float:
        """Return the fraction of *patterns* found anywhere in *text* (0.0 for no patterns)."""
        matches = sum(1 for p in patterns if re.search(p, text))
        return min(matches / max(len(patterns), 1), 1.0)


class NaiveBayesClassifier:
    """
    Simple Gaussian Naive Bayes classifier for attack type classification.

    Implemented in pure Python; it does not use sklearn even when it is
    installed. Trains on a small built-in set of curated examples.
    """

    LABELS = ["sqli", "xss", "rce", "path_traversal", "scanner", "benign"]

    def __init__(self):
        self.feature_extractor = FeatureExtractor()
        self._trained = False

        # Training data (curated examples)
        self._training_data = self._get_training_data()

        # Per-class statistics filled in by train()
        self._class_priors: Dict[str, float] = {}
        self._feature_means: Dict[str, Dict[str, float]] = defaultdict(dict)
        self._feature_vars: Dict[str, Dict[str, float]] = defaultdict(dict)

    def _get_training_data(self) -> List[Tuple[str, str]]:
        """Return curated (text, label) training examples."""
        return [
            # SQLi examples
            ("' OR '1'='1", "sqli"),
            ("1; DROP TABLE users--", "sqli"),
            ("UNION SELECT * FROM passwords", "sqli"),
            ("admin'--", "sqli"),
            ("1' AND 1=1--", "sqli"),
            ("'; INSERT INTO users VALUES('hack','hack')--", "sqli"),
            # XSS examples
            # NOTE(review): the markup-based XSS samples were empty strings in
            # this copy (which trained XSS on all-zero vectors); reconstructed.
            ("<script>alert('XSS')</script>", "xss"),
            ("<img src=x onerror=alert(1)>", "xss"),
            ("javascript:alert(document.cookie)", "xss"),
            ("<svg onload=alert(1)>", "xss"),
            ("'\"><script>alert(1)</script>", "xss"),
            # RCE examples
            ("; cat /etc/passwd", "rce"),
            ("| ls -la", "rce"),
            ("`id`", "rce"),
            ("$(whoami)", "rce"),
            ("; rm -rf /", "rce"),
            ("system('cat /etc/passwd')", "rce"),
            # Path traversal
            ("../../../etc/passwd", "path_traversal"),
            ("..\\..\\..\\windows\\system32\\config\\sam", "path_traversal"),
            ("/etc/passwd%00", "path_traversal"),
            ("....//....//etc/passwd", "path_traversal"),
            # Scanner signatures
            ("Mozilla/5.0 (compatible; Nmap Scripting Engine)", "scanner"),
            ("sqlmap/1.0", "scanner"),
            ("Nikto/2.1.5", "scanner"),
            ("masscan/1.0", "scanner"),
            # Benign examples
            ("/api/users/123", "benign"),
            ("Mozilla/5.0 (Windows NT 10.0; Win64; x64)", "benign"),
            ("/products?category=electronics&page=2", "benign"),
            ("GET /index.html HTTP/1.1", "benign"),
            ("/static/css/main.css", "benign"),
        ]

    def train(self) -> None:
        """Fit class priors and per-class feature means/variances on the built-in examples."""
        X: List[Dict[str, float]] = []
        y: List[str] = []
        for text, label in self._training_data:
            X.append(self.feature_extractor.extract(text))
            y.append(label)

        # Class priors = relative frequency of each label.
        label_counts = Counter(y)
        total = len(y)
        for label, count in label_counts.items():
            self._class_priors[label] = count / total

        # Union of all feature names; a feature missing from an example counts as 0.0.
        all_features = set()
        for features in X:
            all_features.update(features.keys())

        for label in self.LABELS:
            class_features = [x for x, lbl in zip(X, y) if lbl == label]
            if not class_features:
                continue

            for feature in all_features:
                values = [f.get(feature, 0.0) for f in class_features]
                mean = sum(values) / len(values)
                var = sum((v - mean) ** 2 for v in values) / len(values)
                self._feature_means[label][feature] = mean
                self._feature_vars[label][feature] = max(var, 1e-6)  # avoid division by zero

        self._trained = True

    def classify(self, text: str) -> ClassificationResult:
        """Classify *text* into an attack category (trains lazily on first use)."""
        if not self._trained:
            self.train()

        features = self.feature_extractor.extract(text)

        # Log posterior (up to a constant) per class. Iterate over the
        # features seen in training and treat absent ones as 0.0, exactly as
        # train() does; otherwise empty input would be scored by priors alone.
        log_probs: Dict[str, float] = {}
        for label in self.LABELS:
            prior = self._class_priors.get(label)
            if prior is None:
                continue

            means = self._feature_means[label]
            variances = self._feature_vars[label]
            log_prob = math.log(prior)
            for feature, mean in means.items():
                value = features.get(feature, 0.0)
                var = variances[feature]
                # Gaussian likelihood
                log_prob += -0.5 * math.log(2 * math.pi * var)
                log_prob += -0.5 * ((value - mean) ** 2) / var

            log_probs[label] = log_prob

        # Convert to probabilities via softmax (shifted by the max for stability).
        max_log_prob = max(log_probs.values()) if log_probs else 0
        exp_probs = {k: math.exp(v - max_log_prob) for k, v in log_probs.items()}
        total = sum(exp_probs.values())
        probs = {k: v / total for k, v in exp_probs.items()}

        best_label = max(probs, key=probs.get) if probs else "benign"
        confidence = probs.get(best_label, 0.0)

        explanation = self._generate_explanation(text, features, best_label)

        return ClassificationResult(
            label=best_label,
            confidence=confidence,
            probabilities=probs,
            features_used=list(features.keys()),
            explanation=explanation,
        )

    def _generate_explanation(self, text: str, features: Dict[str, float], label: str) -> str:
        """Build a '; '-joined explanation from feature thresholds (falls back to the label)."""
        reasons = []

        if features.get("sqli_score", 0) > 0.3:
            reasons.append("SQL injection patterns detected")
        if features.get("xss_score", 0) > 0.3:
            reasons.append("XSS patterns detected")
        if features.get("rce_score", 0) > 0.3:
            reasons.append("Command injection patterns detected")
        if features.get("path_traversal_score", 0) > 0.3:
            reasons.append("Path traversal patterns detected")
        if features.get("special_char_ratio", 0) > 0.2:
            reasons.append("High special character ratio")
        if features.get("entropy", 0) > 0.7:
            reasons.append("High entropy (possible encoding/obfuscation)")

        if not reasons:
            reasons.append(f"General pattern matching suggests {label}")

        return "; ".join(reasons)


class AnomalyDetector:
    """
    Detect anomalous requests based on baseline behavior.

    Uses per-feature z-scores against statistics from observed requests;
    falls back to fixed heuristics until a baseline exists. Requires no
    ML libraries.
    """

    def __init__(self):
        self.feature_extractor = FeatureExtractor()
        self._baseline_stats: Dict[str, Dict[str, float]] = {}
        self._observations: List[Dict[str, float]] = []

    def add_observation(self, text: str) -> None:
        """Record *text* as normal traffic; the baseline is rebuilt once >= 10 are stored."""
        features = self.feature_extractor.extract(text)
        self._observations.append(features)

        # Rebuilt from all observations on every call after the 10th.
        if len(self._observations) >= 10:
            self._update_baseline()

    def _update_baseline(self) -> None:
        """Recompute mean/std/min/max per feature over all observations (missing = 0.0)."""
        if not self._observations:
            return

        all_features = set()
        for obs in self._observations:
            all_features.update(obs.keys())

        for feature in all_features:
            values = [obs.get(feature, 0.0) for obs in self._observations]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            # Floor a zero spread so z-scores stay finite.
            std = math.sqrt(var) if var > 0 else 0.001

            self._baseline_stats[feature] = {
                "mean": mean,
                "std": std,
                "min": min(values),
                "max": max(values),
            }

    def score(self, text: str) -> AnomalyScore:
        """Score how anomalous *text* is relative to the baseline (heuristics if none)."""
        features = self.feature_extractor.extract(text)

        if not self._baseline_stats:
            # No baseline yet, use heuristics
            return self._heuristic_score(features)

        z_scores: Dict[str, float] = {}
        anomalous_features: List[str] = []

        # Missing features count as 0.0, matching how _update_baseline built the stats.
        for feature, stats in self._baseline_stats.items():
            value = features.get(feature, 0.0)
            z = (value - stats["mean"]) / stats["std"]
            z_scores[feature] = abs(z)

            if abs(z) > 2:  # More than 2 std deviations
                anomalous_features.append(f"{feature} (z={z:.2f})")

        # Score is driven by the worst feature (max |z| / 5, clamped to 1.0);
        # the mean |z| is reported as baseline_deviation.
        if z_scores:
            avg_z = sum(z_scores.values()) / len(z_scores)
            max_z = max(z_scores.values())
            score = min(max_z / 5, 1.0)
            baseline_deviation = avg_z
        else:
            score = 0.5
            baseline_deviation = 0.0

        if score > 0.8:
            recommendation = "BLOCK: Highly anomalous, likely attack"
        elif score > 0.5:
            recommendation = "CHALLENGE: Moderately anomalous, requires verification"
        elif score > 0.3:
            recommendation = "LOG: Slightly unusual, monitor closely"
        else:
            recommendation = "ALLOW: Within normal parameters"

        return AnomalyScore(
            score=score,
            baseline_deviation=baseline_deviation,
            anomalous_features=anomalous_features,
            recommendation=recommendation,
        )

    def _heuristic_score(self, features: Dict[str, float]) -> AnomalyScore:
        """Score from fixed attack-indicator thresholds when no baseline exists."""
        score = 0.0
        anomalous_features: List[str] = []

        # Each attack family above threshold adds 0.25.
        for attack_type in ["sqli_score", "xss_score", "rce_score", "path_traversal_score"]:
            if features.get(attack_type, 0) > 0.3:
                score += 0.25
                anomalous_features.append(attack_type)

        # Check for suspicious characteristics
        if features.get("special_char_ratio", 0) > 0.15:
            score += 0.15
            anomalous_features.append("high_special_chars")

        if features.get("entropy", 0) > 0.8:
            score += 0.1
            anomalous_features.append("high_entropy")

        score = min(score, 1.0)

        if score > 0.7:
            recommendation = "BLOCK: Multiple attack indicators"
        elif score > 0.4:
            recommendation = "CHALLENGE: Suspicious characteristics"
        else:
            recommendation = "ALLOW: No obvious threats"

        return AnomalyScore(
            score=score,
            baseline_deviation=0.0,
            anomalous_features=anomalous_features,
            recommendation=recommendation,
        )


class ThreatClassifier:
    """
    High-level threat classifier combining multiple techniques.

    Usage:
        classifier = ThreatClassifier()
        result = classifier.classify("' OR '1'='1")
        print(f"Label: {result.label}, Confidence: {result.confidence}")
    """

    def __init__(self, model_path: Optional[Path] = None):
        self.naive_bayes = NaiveBayesClassifier()
        self.anomaly_detector = AnomalyDetector()
        # Stored for callers; not read anywhere in this module.
        self.model_path = model_path

        # Train on startup
        self.naive_bayes.train()

    def classify(self, text: str) -> ClassificationResult:
        """Classify a request/pattern."""
        return self.naive_bayes.classify(text)

    def score_anomaly(self, text: str) -> AnomalyScore:
        """Score how anomalous a request is."""
        return self.anomaly_detector.score(text)

    def analyze(self, text: str) -> Dict[str, Any]:
        """Full analysis combining classification, anomaly detection and a risk level."""
        classification = self.classify(text)
        anomaly = self.score_anomaly(text)

        return {
            "classification": {
                "label": classification.label,
                "confidence": classification.confidence,
                "probabilities": classification.probabilities,
                "explanation": classification.explanation,
            },
            "anomaly": {
                "score": anomaly.score,
                "baseline_deviation": anomaly.baseline_deviation,
                "anomalous_features": anomaly.anomalous_features,
                "recommendation": anomaly.recommendation,
            },
            "risk_level": self._compute_risk_level(classification, anomaly),
        }

    def _compute_risk_level(
        self, classification: ClassificationResult, anomaly: AnomalyScore
    ) -> str:
        """Map classification + anomaly to "critical" / "high" / "medium" / "low"."""
        # High-risk attack types
        high_risk_labels = {"sqli", "xss", "rce"}

        if classification.label in high_risk_labels and classification.confidence > 0.7:
            return "critical"
        if classification.label in high_risk_labels and classification.confidence > 0.4:
            return "high"
        if anomaly.score > 0.7:
            return "high"
        if classification.label == "scanner":
            return "medium"
        if anomaly.score > 0.4:
            return "medium"
        return "low"


def _layer0_cli_msg(routing_action: str, shadow: ShadowEvalResult) -> str:
    """Return the CLI message for a Layer 0 routing action other than a Layer 1 handoff.

    Defined before the ``__main__`` block so the CLI can call it.
    """
    if routing_action == "FAIL_CLOSED":
        return "Layer 0: cannot comply with this request."
    if routing_action == "HANDOFF_TO_GUARDRAILS":
        reason = shadow.reason or "governance_violation"
        return f"Layer 0: governance violation detected ({reason})."
    if routing_action == "PROMPT_FOR_CLARIFICATION":
        return "Layer 0: request is ambiguous. Please add specifics before rerunning."
    return "Layer 0: unrecognized routing action; refusing request."


# CLI for testing
if __name__ == "__main__":
    import sys

    classifier = ThreatClassifier()

    test_inputs = [
        "' OR '1'='1",
        "<script>alert('XSS')</script>",
        "; cat /etc/passwd",
        "../../../etc/passwd",
        "Mozilla/5.0 (Windows NT 10.0)",
        "/api/users/123",
    ]

    if len(sys.argv) > 1:
        test_inputs = sys.argv[1:]

    print("\n🤖 ML Threat Classifier Test")
    print("=" * 60)

    for text in test_inputs:
        # Layer 0 gates every input; only a Layer 1 handoff reaches the classifier.
        routing_action, shadow = layer0_entry(text)
        if routing_action != "HANDOFF_TO_LAYER1":
            print(_layer0_cli_msg(routing_action, shadow), file=sys.stderr)
            continue

        result = classifier.analyze(text)
        # Only mark the preview as truncated when it actually is.
        preview = text if len(text) <= 50 else f"{text[:50]}..."
        print(f"\nInput: {preview}")
        print(f"  Label: {result['classification']['label']}")
        print(f"  Confidence: {result['classification']['confidence']:.2%}")
        print(f"  Risk Level: {result['risk_level'].upper()}")
        print(f"  Anomaly Score: {result['anomaly']['score']:.2%}")
        print(f"  Recommendation: {result['anomaly']['recommendation']}")