584 lines
21 KiB
Python
584 lines
21 KiB
Python
"""
|
|
Phase 7: ML-Based Threat Classifier
|
|
|
|
Uses simple but effective ML techniques for:
|
|
- Attack pattern classification (SQLi, XSS, RCE, etc.)
|
|
- Anomaly scoring based on request features
|
|
- Risk-level prediction for proposed rules
|
|
|
|
Designed to work offline without heavy dependencies.
|
|
Uses scikit-learn-style interface but can run with pure Python fallback.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import math
|
|
import re
|
|
from collections import Counter, defaultdict
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
|
|
from layer0 import layer0_entry
|
|
from layer0.shadow_classifier import ShadowEvalResult
|
|
|
|
# Try to import sklearn, fall back to pure Python
|
|
try:
|
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
from sklearn.naive_bayes import MultinomialNB
|
|
from sklearn.preprocessing import LabelEncoder
|
|
HAS_SKLEARN = True
|
|
except ImportError:
|
|
HAS_SKLEARN = False
|
|
|
|
|
|
@dataclass
class ClassificationResult:
    """Outcome of classifying one threat indicator or request pattern.

    Attributes:
        label: Predicted category, e.g. "sqli", "xss", "rce",
            "path_traversal", "scanner" or "benign".
        confidence: Probability assigned to ``label`` (0.0-1.0).
        probabilities: Probability per candidate label; empty by default.
        features_used: Names of the features the prediction was based on.
        explanation: Human-readable justification of the prediction.
    """

    label: str
    confidence: float
    probabilities: Dict[str, float] = field(default_factory=dict)
    features_used: List[str] = field(default_factory=list)
    explanation: str = ""
|
|
|
|
|
|
@dataclass
class AnomalyScore:
    """Outcome of scoring how unusual a request looks.

    Attributes:
        score: Anomaly level from 0.0 (normal) to 1.0 (most anomalous).
        baseline_deviation: Deviation from the baseline, in standard deviations.
        anomalous_features: Descriptions of the features that stood out.
        recommendation: Suggested action text (e.g. "ALLOW: ...").
    """

    score: float
    baseline_deviation: float
    anomalous_features: List[str] = field(default_factory=list)
    recommendation: str = ""
|
|
|
|
|
|
class FeatureExtractor:
    """Turn raw request/log text into a flat dict of numeric features.

    Every value is a float, mostly ratios or scores in the 0.0-1.0 range.
    For non-empty input the same keys are always produced, in the same
    insertion order; empty input yields an empty dict.
    """

    # Characters counted by the "special_char_ratio" feature.
    SPECIAL_CHARS = set("'\"<>(){}[];=&|`$\\")

    # Signature families; each feature is the fraction of a family that matches.
    # NOTE(review): "select\s+.*\s+from" and "update\s+.*\s+set" put two
    # whitespace-capable quantifiers around ".*", which backtracks heavily on
    # long whitespace runs in attacker-controlled input; consider bounding the
    # input length or rewriting these patterns.
    SQLI_PATTERNS = [
        r"(?i)union\s+select",
        r"(?i)select\s+.*\s+from",
        r"(?i)insert\s+into",
        r"(?i)update\s+.*\s+set",
        r"(?i)delete\s+from",
        r"(?i)drop\s+table",
        r"(?i);\s*--",
        r"(?i)'\s*or\s+'?1'?\s*=\s*'?1",
        r"(?i)'\s*and\s+'?1'?\s*=\s*'?1",
    ]

    XSS_PATTERNS = [
        r"(?i)<script",
        r"(?i)javascript:",
        r"(?i)on\w+\s*=",
        r"(?i)alert\s*\(",
        r"(?i)document\.",
        r"(?i)window\.",
        r"(?i)eval\s*\(",
    ]

    RCE_PATTERNS = [
        r"(?i);\s*(?:cat|ls|id|whoami|pwd)",
        r"(?i)\|\s*(?:cat|ls|id|whoami)",
        r"(?i)`[^`]+`",
        r"(?i)\$\([^)]+\)",
        r"(?i)system\s*\(",
        r"(?i)exec\s*\(",
        r"(?i)passthru\s*\(",
    ]

    PATH_TRAVERSAL_PATTERNS = [
        r"\.\./",
        r"\.\.\\",
        r"(?i)etc/passwd",
        r"(?i)windows/system32",
    ]

    def extract(self, text: str) -> Dict[str, float]:
        """Compute the feature vector for ``text``.

        Returns an empty dict when ``text`` is empty (or otherwise falsy).
        """
        if not text:
            return {}

        n_chars = len(text)
        denom = max(n_chars, 1)
        lowered = text.lower()

        # One pass over the characters for the three class-ratio features.
        n_special = 0
        n_upper = 0
        n_digit = 0
        for ch in text:
            if ch in self.SPECIAL_CHARS:
                n_special += 1
            if ch.isupper():
                n_upper += 1
            if ch.isdigit():
                n_digit += 1

        return {
            # Size: linear (capped at 1000 chars) and log-scaled.
            "length": min(n_chars / 1000, 1.0),
            "length_log": math.log1p(n_chars) / 10,
            # Character-class ratios.
            "special_char_ratio": n_special / denom,
            "uppercase_ratio": n_upper / denom,
            "digit_ratio": n_digit / denom,
            # Randomness; high values may indicate encoding/obfuscation.
            "entropy": self._calculate_entropy(text),
            # Fraction of each signature family that matched.
            "sqli_score": self._pattern_score(text, self.SQLI_PATTERNS),
            "xss_score": self._pattern_score(text, self.XSS_PATTERNS),
            "rce_score": self._pattern_score(text, self.RCE_PATTERNS),
            "path_traversal_score": self._pattern_score(text, self.PATH_TRAVERSAL_PATTERNS),
            # Structural punctuation densities.
            "quote_count": (text.count("'") + text.count('"')) / denom,
            "paren_count": (text.count("(") + text.count(")")) / denom,
            "bracket_count": sum(text.count(ch) for ch in "[]{}") / denom,
            # Keyword presence flags (case-insensitive).
            "has_select": 1.0 if "select" in lowered else 0.0,
            "has_script": 1.0 if "<script" in lowered else 0.0,
            "has_etc_passwd": 1.0 if "etc/passwd" in lowered else 0.0,
        }

    def _calculate_entropy(self, text: str) -> float:
        """Return the Shannon entropy of ``text`` in bits, scaled into 0.0-1.0."""
        if not text:
            return 0.0

        total = len(text)
        bits = 0.0
        for occurrences in Counter(text).values():
            p = occurrences / total
            if p > 0:
                bits -= p * math.log2(p)

        # 7 bits is the maximum for a 128-symbol (ASCII) alphabet; clamp above it.
        return min(bits / 7, 1.0)

    def _pattern_score(self, text: str, patterns: List[str]) -> float:
        """Return the fraction of ``patterns`` that match anywhere in ``text``."""
        hits = 0
        for pattern in patterns:
            if re.search(pattern, text):
                hits += 1
        return min(hits / max(len(patterns), 1), 1.0)
|
|
|
|
|
|
class NaiveBayesClassifier:
    """
    Gaussian Naive Bayes classifier for attack-type classification.

    Each feature produced by FeatureExtractor is modelled per class as an
    independent Gaussian whose mean and variance are estimated from a small
    built-in set of labelled examples. The implementation is pure Python and
    does not use scikit-learn, even when it is installed.
    """

    LABELS = ["sqli", "xss", "rce", "path_traversal", "scanner", "benign"]

    def __init__(self):
        self.feature_extractor = FeatureExtractor()
        self._trained = False

        # Curated (text, label) examples consumed by train().
        self._training_data = self._get_training_data()

        # Model parameters filled in by train(): label -> prior probability,
        # and label -> feature -> mean / variance.
        self._class_priors: Dict[str, float] = {}
        self._feature_means: Dict[str, Dict[str, float]] = defaultdict(dict)
        self._feature_vars: Dict[str, Dict[str, float]] = defaultdict(dict)

    def _get_training_data(self) -> List[Tuple[str, str]]:
        """Return the curated (text, label) training examples."""
        return [
            # SQLi examples
            ("' OR '1'='1", "sqli"),
            ("1; DROP TABLE users--", "sqli"),
            ("UNION SELECT * FROM passwords", "sqli"),
            ("admin'--", "sqli"),
            ("1' AND 1=1--", "sqli"),
            ("'; INSERT INTO users VALUES('hack','hack')--", "sqli"),

            # XSS examples
            ("<script>alert('xss')</script>", "xss"),
            ("<img src=x onerror=alert(1)>", "xss"),
            ("javascript:alert(document.cookie)", "xss"),
            ("<svg onload=alert(1)>", "xss"),
            ("'\"><script>alert('XSS')</script>", "xss"),

            # RCE examples
            ("; cat /etc/passwd", "rce"),
            ("| ls -la", "rce"),
            ("`id`", "rce"),
            ("$(whoami)", "rce"),
            ("; rm -rf /", "rce"),
            ("system('cat /etc/passwd')", "rce"),

            # Path traversal
            ("../../../etc/passwd", "path_traversal"),
            ("..\\..\\..\\windows\\system32\\config\\sam", "path_traversal"),
            ("/etc/passwd%00", "path_traversal"),
            ("....//....//etc/passwd", "path_traversal"),

            # Scanner signatures
            ("Mozilla/5.0 (compatible; Nmap Scripting Engine)", "scanner"),
            ("sqlmap/1.0", "scanner"),
            ("Nikto/2.1.5", "scanner"),
            ("masscan/1.0", "scanner"),

            # Benign examples
            ("/api/users/123", "benign"),
            ("Mozilla/5.0 (Windows NT 10.0; Win64; x64)", "benign"),
            ("/products?category=electronics&page=2", "benign"),
            ("GET /index.html HTTP/1.1", "benign"),
            ("/static/css/main.css", "benign"),
        ]

    def train(self) -> None:
        """Fit class priors and per-class Gaussian feature statistics.

        Uses ``self._training_data``. Previously fitted statistics are
        discarded, so calling this again retrains from scratch.
        """
        X: List[Dict[str, float]] = [
            self.feature_extractor.extract(text) for text, _ in self._training_data
        ]
        y: List[str] = [label for _, label in self._training_data]

        self._class_priors.clear()
        self._feature_means.clear()
        self._feature_vars.clear()

        # Class priors: relative frequency of each label in the training set.
        total = len(y)
        for label, count in Counter(y).items():
            self._class_priors[label] = count / total

        # Union of feature names; a feature absent from an example counts as 0.0.
        all_features = set().union(*X)

        for label in self.LABELS:
            class_features = [feats for feats, lbl in zip(X, y) if lbl == label]
            if not class_features:
                continue

            for feature in all_features:
                values = [f.get(feature, 0.0) for f in class_features]
                mean = sum(values) / len(values)
                var = sum((v - mean) ** 2 for v in values) / len(values)
                self._feature_means[label][feature] = mean
                # Floor the variance so the Gaussian log-density stays finite.
                self._feature_vars[label][feature] = max(var, 1e-6)

        self._trained = True

    def classify(self, text: str) -> ClassificationResult:
        """Classify ``text`` into one of ``LABELS``.

        Trains the model on first use. Empty input has no features and is
        reported as "benign" with confidence 1.0.
        """
        if not self._trained:
            self.train()

        features = self.feature_extractor.extract(text)

        if not features:
            # With no features only the class priors would remain, which just
            # picks the most frequent training label (a tie won by "sqli").
            return ClassificationResult(
                label="benign",
                confidence=1.0,
                probabilities={"benign": 1.0},
                features_used=[],
                explanation="Empty input; no features to classify",
            )

        # Log posterior (up to a constant) for each trained class.
        log_probs: Dict[str, float] = {}

        for label in self.LABELS:
            if label not in self._class_priors:
                continue

            log_prob = math.log(self._class_priors[label])

            for feature, value in features.items():
                if feature in self._feature_means[label]:
                    mean = self._feature_means[label][feature]
                    var = self._feature_vars[label][feature]
                    # Gaussian log-likelihood
                    log_prob += -0.5 * math.log(2 * math.pi * var)
                    log_prob += -0.5 * ((value - mean) ** 2) / var

            log_probs[label] = log_prob

        # Softmax; subtracting the max keeps exp() from overflowing/underflowing.
        max_log_prob = max(log_probs.values()) if log_probs else 0
        exp_probs = {k: math.exp(v - max_log_prob) for k, v in log_probs.items()}
        total = sum(exp_probs.values())
        probs = {k: v / total for k, v in exp_probs.items()}

        best_label = max(probs, key=probs.get) if probs else "benign"
        confidence = probs.get(best_label, 0.0)

        explanation = self._generate_explanation(text, features, best_label)

        return ClassificationResult(
            label=best_label,
            confidence=confidence,
            probabilities=probs,
            features_used=list(features.keys()),
            explanation=explanation
        )

    def _generate_explanation(self, text: str, features: Dict[str, float], label: str) -> str:
        """Build a "; "-joined list of reasons derived from feature thresholds.

        Falls back to a generic sentence naming ``label`` when no threshold is
        exceeded. ``text`` is accepted but not used.
        """
        reasons = []

        if features.get("sqli_score", 0) > 0.3:
            reasons.append("SQL injection patterns detected")
        if features.get("xss_score", 0) > 0.3:
            reasons.append("XSS patterns detected")
        if features.get("rce_score", 0) > 0.3:
            reasons.append("Command injection patterns detected")
        if features.get("path_traversal_score", 0) > 0.3:
            reasons.append("Path traversal patterns detected")
        if features.get("special_char_ratio", 0) > 0.2:
            reasons.append("High special character ratio")
        if features.get("entropy", 0) > 0.7:
            reasons.append("High entropy (possible encoding/obfuscation)")

        if not reasons:
            reasons.append(f"General pattern matching suggests {label}")

        return "; ".join(reasons)
|
|
|
|
|
|
class AnomalyDetector:
    """
    Detect anomalous requests by comparing them with a learned baseline.

    Observed requests are reduced to FeatureExtractor features. Once enough
    observations exist, per-feature mean and standard deviation are used to
    compute z-scores for new requests; before that, a fixed heuristic over
    the attack-pattern and character features is used. Only the standard
    library is required.
    """

    # Observations needed before the statistical baseline is (re)computed.
    MIN_OBSERVATIONS = 10
    # Lower bound for a baseline feature's standard deviation (equivalent to
    # the 1e-6 variance floor used by NaiveBayesClassifier); keeps z finite.
    MIN_STD = 0.001

    def __init__(self):
        self.feature_extractor = FeatureExtractor()
        # feature name -> {"mean", "std", "min", "max"}
        self._baseline_stats: Dict[str, Dict[str, float]] = {}
        # Feature dicts of every observation added so far.
        self._observations: List[Dict[str, float]] = []

    def add_observation(self, text: str) -> None:
        """Record ``text`` as baseline traffic.

        Once at least MIN_OBSERVATIONS have been collected, the baseline is
        recomputed from all stored observations on every call.
        """
        features = self.feature_extractor.extract(text)
        self._observations.append(features)

        # NOTE(review): observations are never evicted and each call rescans
        # all of them; consider a bounded window or running statistics for
        # long-lived detectors.
        if len(self._observations) >= self.MIN_OBSERVATIONS:
            self._update_baseline()

    def _update_baseline(self) -> None:
        """Recompute per-feature mean/std/min/max over all observations.

        A feature missing from an observation counts as 0.0 for it.
        """
        if not self._observations:
            return

        all_features = set()
        for obs in self._observations:
            all_features.update(obs.keys())

        for feature in all_features:
            values = [obs.get(feature, 0.0) for obs in self._observations]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            # Floor the std instead of only replacing an exact zero: rounding
            # in the mean can leave a tiny positive variance for a constant
            # feature, which would otherwise produce enormous z-scores.
            std = max(math.sqrt(var), self.MIN_STD)

            self._baseline_stats[feature] = {
                "mean": mean,
                "std": std,
                "min": min(values),
                "max": max(values),
            }

    def score(self, text: str) -> AnomalyScore:
        """Score how anomalous ``text`` is relative to the baseline.

        Without a baseline, delegates to ``_heuristic_score``. Otherwise the
        score is max |z| / 5 clamped to 1.0, ``baseline_deviation`` is the
        mean |z|, and features with |z| > 2 are listed. If no feature of
        ``text`` is in the baseline, the score is 0.5.
        """
        features = self.feature_extractor.extract(text)

        if not self._baseline_stats:
            # No baseline yet, use heuristics
            return self._heuristic_score(features)

        z_scores: Dict[str, float] = {}
        anomalous_features: List[str] = []

        for feature, value in features.items():
            if feature in self._baseline_stats:
                stats = self._baseline_stats[feature]
                z = (value - stats["mean"]) / stats["std"]
                z_scores[feature] = abs(z)

                if abs(z) > 2:  # More than 2 std deviations
                    anomalous_features.append(f"{feature} (z={z:.2f})")

        if z_scores:
            avg_z = sum(z_scores.values()) / len(z_scores)
            max_z = max(z_scores.values())
            score = min(max_z / 5, 1.0)  # Normalize to 0-1
            baseline_deviation = avg_z
        else:
            score = 0.5
            baseline_deviation = 0.0

        if score > 0.8:
            recommendation = "BLOCK: Highly anomalous, likely attack"
        elif score > 0.5:
            recommendation = "CHALLENGE: Moderately anomalous, requires verification"
        elif score > 0.3:
            recommendation = "LOG: Slightly unusual, monitor closely"
        else:
            recommendation = "ALLOW: Within normal parameters"

        return AnomalyScore(
            score=score,
            baseline_deviation=baseline_deviation,
            anomalous_features=anomalous_features,
            recommendation=recommendation
        )

    def _heuristic_score(self, features: Dict[str, float]) -> AnomalyScore:
        """Score ``features`` with fixed thresholds when no baseline exists.

        Adds 0.25 per attack-pattern score above 0.3, 0.15 for a special
        character ratio above 0.15 and 0.1 for entropy above 0.8, clamped
        to 1.0.
        """
        score = 0.0
        anomalous_features: List[str] = []

        for attack_type in ["sqli_score", "xss_score", "rce_score", "path_traversal_score"]:
            if features.get(attack_type, 0) > 0.3:
                score += 0.25
                anomalous_features.append(attack_type)

        if features.get("special_char_ratio", 0) > 0.15:
            score += 0.15
            anomalous_features.append("high_special_chars")

        if features.get("entropy", 0) > 0.8:
            score += 0.1
            anomalous_features.append("high_entropy")

        score = min(score, 1.0)

        if score > 0.7:
            recommendation = "BLOCK: Multiple attack indicators"
        elif score > 0.4:
            recommendation = "CHALLENGE: Suspicious characteristics"
        else:
            recommendation = "ALLOW: No obvious threats"

        return AnomalyScore(
            score=score,
            baseline_deviation=0.0,
            anomalous_features=anomalous_features,
            recommendation=recommendation
        )
|
|
|
|
|
|
class ThreatClassifier:
    """
    Facade combining attack-type classification with anomaly scoring.

    Usage:
        classifier = ThreatClassifier()
        result = classifier.classify("' OR '1'='1")
        print(f"Label: {result.label}, Confidence: {result.confidence}")
    """

    def __init__(self, model_path: Optional[Path] = None):
        self.naive_bayes = NaiveBayesClassifier()
        self.anomaly_detector = AnomalyDetector()
        # Kept on the instance; nothing in this class reads it.
        self.model_path = model_path

        # Fit eagerly so the first classify() does not have to train lazily.
        self.naive_bayes.train()

    def classify(self, text: str) -> ClassificationResult:
        """Return the Naive Bayes attack-type prediction for ``text``."""
        return self.naive_bayes.classify(text)

    def score_anomaly(self, text: str) -> AnomalyScore:
        """Return the anomaly detector's score for ``text``."""
        return self.anomaly_detector.score(text)

    def analyze(self, text: str) -> Dict[str, Any]:
        """Classify and anomaly-score ``text``, returning a plain-dict summary.

        The result has "classification", "anomaly" and "risk_level" keys.
        """
        prediction = self.classify(text)
        anomaly = self.score_anomaly(text)

        classification_summary = {
            "label": prediction.label,
            "confidence": prediction.confidence,
            "probabilities": prediction.probabilities,
            "explanation": prediction.explanation,
        }
        anomaly_summary = {
            "score": anomaly.score,
            "baseline_deviation": anomaly.baseline_deviation,
            "anomalous_features": anomaly.anomalous_features,
            "recommendation": anomaly.recommendation,
        }
        return {
            "classification": classification_summary,
            "anomaly": anomaly_summary,
            "risk_level": self._compute_risk_level(prediction, anomaly),
        }

    def _compute_risk_level(
        self,
        classification: ClassificationResult,
        anomaly: AnomalyScore
    ) -> str:
        """Map a prediction and anomaly score to "critical"/"high"/"medium"/"low".

        SQLi, XSS and RCE predictions escalate by confidence; otherwise the
        anomaly score (and a "scanner" label) decide.
        """
        severe_type = classification.label in {"sqli", "xss", "rce"}
        confidence = classification.confidence

        if severe_type and confidence > 0.7:
            return "critical"
        if (severe_type and confidence > 0.4) or anomaly.score > 0.7:
            return "high"
        if classification.label == "scanner" or anomaly.score > 0.4:
            return "medium"
        return "low"
|
|
|
|
|
|
# CLI for testing


def _layer0_cli_msg(routing_action: str, shadow: ShadowEvalResult) -> str:
    """Map a non-handoff Layer 0 routing action to a CLI message.

    ``shadow`` is only consulted for "HANDOFF_TO_GUARDRAILS", where its
    ``reason`` (or "governance_violation" when falsy) is included. Any
    unrecognized action yields a refusal message.
    """
    if routing_action == "FAIL_CLOSED":
        return "Layer 0: cannot comply with this request."

    if routing_action == "HANDOFF_TO_GUARDRAILS":
        reason = shadow.reason or "governance_violation"
        return f"Layer 0: governance violation detected ({reason})."

    if routing_action == "PROMPT_FOR_CLARIFICATION":
        return "Layer 0: request is ambiguous. Please add specifics before rerunning."

    return "Layer 0: unrecognized routing action; refusing request."


# The helper above must be defined before this block: the script body runs
# top to bottom, so a function defined after it would not exist yet when a
# non-handoff routing action is reported.
if __name__ == "__main__":
    import sys

    classifier = ThreatClassifier()

    test_inputs = [
        "' OR '1'='1",
        "<script>alert('xss')</script>",
        "; cat /etc/passwd",
        "../../../etc/passwd",
        "Mozilla/5.0 (Windows NT 10.0)",
        "/api/users/123",
    ]

    # Command-line arguments, when given, replace the built-in samples.
    if len(sys.argv) > 1:
        test_inputs = sys.argv[1:]

    print("\n🤖 ML Threat Classifier Test")
    print("=" * 60)

    for text in test_inputs:
        # Layer 0 gate: only inputs handed off to Layer 1 are classified.
        routing_action, shadow = layer0_entry(text)
        if routing_action != "HANDOFF_TO_LAYER1":
            print(_layer0_cli_msg(routing_action, shadow), file=sys.stderr)
            continue

        result = classifier.analyze(text)
        # Only mark the input as truncated when it actually was.
        preview = text if len(text) <= 50 else f"{text[:50]}..."
        print(f"\nInput: {preview}")
        print(f"  Label: {result['classification']['label']}")
        print(f"  Confidence: {result['classification']['confidence']:.2%}")
        print(f"  Risk Level: {result['risk_level'].upper()}")
        print(f"  Anomaly Score: {result['anomaly']['score']:.2%}")
        print(f"  Recommendation: {result['anomaly']['recommendation']}")
|