from __future__ import annotations

import json
import os
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable, Sequence

THIS_FILE = Path(__file__).resolve()
LAYER0_DIR = THIS_FILE.parent
REPO_ROOT = LAYER0_DIR.parent.parent

_RE_URL = re.compile(r"\bhttps?://\S+\b", re.IGNORECASE)
_RE_EMAIL = re.compile(r"\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b", re.IGNORECASE)
_RE_IPV4 = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")
_RE_IPV6 = re.compile(r"\b(?:[0-9a-f]{0,4}:){2,}[0-9a-f]{0,4}\b", re.IGNORECASE)
_RE_UUID = re.compile(
    r"\b[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}\b",
    re.IGNORECASE,
)
_RE_HEX_LONG = re.compile(r"\b[0-9a-f]{32,}\b", re.IGNORECASE)
_RE_BASE64ISH = re.compile(r"\b[A-Za-z0-9+/]{28,}={0,2}\b")
_RE_PATHISH = re.compile(r"(?:(?:\.\.?/)|/)[A-Za-z0-9._~/-]{2,}")
_RE_NUMBER = re.compile(r"\b\d+\b")
_RE_TOKEN = re.compile(r"[a-z][a-z_-]{1,31}", re.IGNORECASE)

SAFE_VOCAB = {
    # Governance / safety verbs
    "disable", "override", "bypass", "skip", "ignore", "evade", "break",
    "force", "apply", "deploy", "destroy", "delete", "drop", "remove",
    "exfiltrate",
    # Critical nouns / domains
    "guardrails", "permissions", "governance", "git", "gitops", "dashboard",
    "manual", "prod", "production", "staging", "terraform", "waf", "dns",
    "tunnel", "access", "token", "secret", "key", "credential", "admin", "root",
    # Phrases often seen in L0 rules (tokenized)
    "self", "modifying", "directly",
}


def _utc_now_iso_z() -> str:
    return (
        datetime.now(timezone.utc)
        .replace(microsecond=0)
        .isoformat()
        .replace("+00:00", "Z")
    )


def normalize_query_for_matching(query: str) -> str:
    """
    Produce a low-leakage normalized string suitable for storing and matching.

    Invariants:
    - Never stores raw URLs, IPs, emails, long hex strings, base64ish blobs,
      UUIDs, or paths; they are replaced with placeholder tokens.
    - Numbers are stripped to <num>.
    - Only safe vocabulary tokens are preserved; other words are dropped.
    """
    q = (query or "").lower().strip()
    if not q:
        return ""
    # Keep placeholders lowercase to make matching stable across sources.
    q = _RE_URL.sub("<url>", q)
    q = _RE_EMAIL.sub("<email>", q)
    q = _RE_IPV4.sub("<ip>", q)
    q = _RE_IPV6.sub("<ip>", q)
    q = _RE_UUID.sub("<uuid>", q)
    q = _RE_PATHISH.sub("<path>", q)
    q = _RE_HEX_LONG.sub("<hex>", q)
    q = _RE_BASE64ISH.sub("<b64>", q)
    q = _RE_NUMBER.sub("<num>", q)
    # Tokenize; keep placeholders and a tight safe vocabulary.
    tokens: list[str] = []
    for raw in re.split(r"[^a-z0-9_<>\-_/]+", q):
        t = raw.strip()
        if not t:
            continue
        if t.startswith("<") and t.endswith(">"):
            tokens.append(t)
            continue
        if _RE_TOKEN.fullmatch(t) and t in SAFE_VOCAB:
            tokens.append(t)
    # De-dupe while preserving order.
    seen: set[str] = set()
    out: list[str] = []
    for t in tokens:
        if t in seen:
            continue
        seen.add(t)
        out.append(t)
    return " ".join(out)


def normalized_tokens(query: str) -> list[str]:
    s = normalize_query_for_matching(query)
    return s.split() if s else []
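
# Illustrative sketch of the normalizer's behavior (the sample query below is an
# assumption for documentation purposes, not a fixture from this repo): only
# SAFE_VOCAB tokens and <...> placeholders survive tokenization, so a raw request
# collapses to a short, low-leakage token string.
#
#     normalize_query_for_matching(
#         "Please bypass the terraform guardrails and deploy to prod"
#     )
#     # -> "bypass terraform guardrails deploy prod"
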
@dataclass(frozen=True)
class LearnedPattern:
    pattern_id: str
    tokens_all: tuple[str, ...]
    classification: str
    reason: str | None
    risk_score: int
    flags: tuple[str, ...]
    specificity_score: int
    min_support: int
    last_seen: str | None
    source: dict[str, Any] | None
    mode: str  # "escalate" | "relax"

    def matches(self, normalized_query: str) -> bool:
        if not normalized_query:
            return False
        hay = set(normalized_query.split())
        return all(t in hay for t in self.tokens_all)


def _default_active_path() -> Path:
    configured = os.environ.get("LAYER0_ACTIVE_PATTERNS_PATH")
    if configured:
        return Path(configured).expanduser().resolve()
    return (REPO_ROOT / ".state" / "layer0_patterns_active.json").resolve()


class PatternStore:
    """
    Read-only active pattern snapshot.

    This is intentionally immutable during request handling; mutations happen
    in offline jobs (learn/replay) that write a new snapshot and log an
    artifact.
    """

    def __init__(self, active_path: Path | None = None):
        self._active_path = active_path or _default_active_path()
        self._active: list[LearnedPattern] = []
        self._loaded = False

    @property
    def active_path(self) -> Path:
        return self._active_path

    def load(self) -> None:
        if self._loaded:
            return
        self._loaded = True
        self._active = self._load_patterns_file(self._active_path)

    def patterns(self) -> list[LearnedPattern]:
        self.load()
        return list(self._active)

    def match_ordered(self, normalized_query: str) -> list[LearnedPattern]:
        self.load()
        matched = [p for p in self._active if p.matches(normalized_query)]
        severity_rank = {
            "blessed": 0,
            "ambiguous": 1,
            "forbidden": 2,
            "catastrophic": 3,
        }
        # Most severe, most specific, best supported, most recently seen first.
        matched.sort(
            key=lambda p: (
                severity_rank.get(p.classification, 0),
                p.specificity_score,
                p.min_support,
                p.last_seen or "",
            ),
            reverse=True,
        )
        return matched

    @staticmethod
    def _load_patterns_file(path: Path) -> list[LearnedPattern]:
        if not path.exists():
            return []
        data = json.loads(path.read_text(encoding="utf-8"))
        items = data.get("patterns") if isinstance(data, dict) else data
        if not isinstance(items, list):
            return []
        patterns: list[LearnedPattern] = []
        for item in items:
            if not isinstance(item, dict):
                continue
            tokens = item.get("tokens_all") or item.get("tokens") or []
            if not isinstance(tokens, list) or not tokens:
                continue
            tokens_norm = tuple(
                t.lower() if isinstance(t, str) else ""
                for t in tokens
                if isinstance(t, str)
                and t
                and (t.startswith("<") or t.lower() in SAFE_VOCAB)
            )
            if not tokens_norm:
                continue
            classification = item.get("classification")
            if classification not in {
                "blessed",
                "ambiguous",
                "forbidden",
                "catastrophic",
            }:
                continue
            flags = item.get("flags") or []
            if not isinstance(flags, list):
                flags = []
            mode = item.get("mode") or "escalate"
            if mode not in {"escalate", "relax"}:
                mode = "escalate"
            min_support = int(item.get("min_support") or item.get("support") or 0)
            specificity = int(item.get("specificity_score") or len(tokens_norm))
            risk_score = int(item.get("risk_score") or 0)
            patterns.append(
                LearnedPattern(
                    pattern_id=str(item.get("pattern_id") or item.get("id") or ""),
                    tokens_all=tokens_norm,
                    classification=classification,
                    reason=item.get("reason"),
                    risk_score=risk_score,
                    flags=tuple(str(f) for f in flags if isinstance(f, str)),
                    specificity_score=specificity,
                    min_support=min_support,
                    last_seen=item.get("last_seen"),
                    source=item.get("source")
                    if isinstance(item.get("source"), dict)
                    else None,
                    mode=mode,
                )
            )
        severity_rank = {
            "blessed": 0,
            "ambiguous": 1,
            "forbidden": 2,
            "catastrophic": 3,
        }
        patterns.sort(
            key=lambda p: (
                severity_rank.get(p.classification, 0),
                p.specificity_score,
                p.min_support,
                p.last_seen or "",
            ),
            reverse=True,
        )
        return patterns
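
# Illustrative snapshot layout (a sketch for documentation; the field values are
# made-up examples, not data from this repo). pattern_dict() below emits exactly
# this per-item shape, write_pattern_snapshot() wraps it in the outer object, and
# PatternStore._load_patterns_file() accepts it:
#
#     {
#       "generated_at": "2024-01-01T00:00:00Z",
#       "patterns": [
#         {
#           "pattern_id": "example-0001",
#           "tokens_all": ["bypass", "guardrails", "prod"],
#           "classification": "forbidden",
#           "reason": "direct guardrail bypass in production",
#           "risk_score": 80,
#           "flags": ["learned"],
#           "specificity_score": 3,
#           "min_support": 5,
#           "last_seen": "2024-01-01T00:00:00Z",
#           "source": {},
#           "mode": "escalate"
#         }
#       ]
#     }
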
def pattern_dict(
    *,
    tokens_all: Sequence[str],
    classification: str,
    reason: str | None,
    risk_score: int,
    flags: Sequence[str],
    min_support: int,
    last_seen: str | None = None,
    source: dict[str, Any] | None = None,
    mode: str = "escalate",
    pattern_id: str | None = None,
) -> dict[str, Any]:
    tokens = [t for t in tokens_all if isinstance(t, str) and t]
    return {
        "pattern_id": pattern_id or "",
        "tokens_all": tokens,
        "classification": classification,
        "reason": reason,
        "risk_score": int(risk_score),
        "flags": list(flags),
        "specificity_score": int(len(tokens)),
        "min_support": int(min_support),
        "last_seen": last_seen or _utc_now_iso_z(),
        "source": source or {},
        "mode": mode,
    }


def write_pattern_snapshot(path: Path, patterns: Iterable[dict[str, Any]]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = {"generated_at": _utc_now_iso_z(), "patterns": list(patterns)}
    path.write_text(
        json.dumps(payload, ensure_ascii=False, sort_keys=True, indent=2) + "\n",
        encoding="utf-8",
    )
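
# Minimal end-to-end sketch (illustrative only; the file name, sample pattern,
# and sample query are assumptions, not fixtures from this repo). It writes a
# tiny snapshot to a temporary directory, points a PatternStore at it, and
# matches a normalized query against the learned pattern.
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        snapshot_path = Path(tmp) / "layer0_patterns_active.json"
        write_pattern_snapshot(
            snapshot_path,
            [
                pattern_dict(
                    tokens_all=["bypass", "guardrails", "prod"],
                    classification="forbidden",
                    reason="example: direct guardrail bypass in production",
                    risk_score=80,
                    flags=["example"],
                    min_support=1,
                    pattern_id="example-0001",
                )
            ],
        )
        store = PatternStore(active_path=snapshot_path)
        query = normalize_query_for_matching(
            "Please bypass the guardrails and deploy straight to prod"
        )
        for p in store.match_ordered(query):
            print(p.pattern_id, p.classification, p.tokens_all)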