from __future__ import annotations import glob import json import os import sys from dataclasses import asdict from pathlib import Path from typing import Any, Callable, Dict, List, Optional from cloudflare.layer0 import layer0_entry from cloudflare.layer0.shadow_classifier import ShadowEvalResult from .orchestrator import ThreatAssessment, WAFInsight, WAFIntelligence MAX_BYTES_DEFAULT = 32_000 def _cloudflare_root() -> Path: # mcp_server.py -> waf_intelligence -> mcp -> cloudflare return Path(__file__).resolve().parents[2] def _max_bytes() -> int: raw = (os.getenv("VM_MCP_MAX_BYTES") or "").strip() if not raw: return MAX_BYTES_DEFAULT try: return max(4_096, int(raw)) except ValueError: return MAX_BYTES_DEFAULT def _redact(obj: Any) -> Any: sensitive_keys = ("token", "secret", "password", "private", "key", "certificate") if isinstance(obj, dict): out: Dict[str, Any] = {} for k, v in obj.items(): if any(s in str(k).lower() for s in sensitive_keys): out[k] = "" else: out[k] = _redact(v) return out if isinstance(obj, list): return [_redact(v) for v in obj] if isinstance(obj, str): if obj.startswith("ghp_") or obj.startswith("github_pat_"): return "" return obj return obj def _safe_json(payload: Dict[str, Any]) -> str: payload = _redact(payload) raw = json.dumps(payload, ensure_ascii=False, separators=(",", ":"), default=str) if len(raw.encode("utf-8")) <= _max_bytes(): return json.dumps(payload, ensure_ascii=False, indent=2, default=str) truncated = { "ok": payload.get("ok", True), "truncated": True, "summary": payload.get("summary", "Response exceeded max size; truncated."), "next_steps": payload.get( "next_steps", [ "request fewer files/insights (limit=...)", "use higher min_severity to reduce output", ], ), } return json.dumps(truncated, ensure_ascii=False, indent=2, default=str) def _mcp_text_result( payload: Dict[str, Any], *, is_error: bool = False ) -> Dict[str, Any]: result: Dict[str, Any] = { "content": [{"type": "text", "text": _safe_json(payload)}] } if is_error: result["isError"] = True return result def _insight_to_dict(insight: WAFInsight) -> Dict[str, Any]: return asdict(insight) def _assessment_to_dict(assessment: ThreatAssessment) -> Dict[str, Any]: violations = [] if assessment.analysis_result and getattr( assessment.analysis_result, "violations", None ): violations = list(assessment.analysis_result.violations) severity_counts = {"error": 0, "warning": 0, "info": 0} for v in violations: sev = getattr(v, "severity", "info") if sev in severity_counts: severity_counts[sev] += 1 return { "risk_score": assessment.risk_score, "risk_level": assessment.risk_level, "classification_summary": assessment.classification_summary, "recommended_actions": assessment.recommended_actions, "analysis": { "has_config_analysis": assessment.analysis_result is not None, "violations_total": len(violations), "violations_by_severity": severity_counts, }, "has_threat_intel": assessment.threat_report is not None, "generated_at": str(assessment.generated_at), } TOOLS: List[Dict[str, Any]] = [ { "name": "waf_capabilities", "description": "List available WAF Intelligence capabilities.", "inputSchema": {"type": "object", "properties": {}}, }, { "name": "analyze_waf", "description": "Analyze Terraform WAF file(s) and return curated insights (legacy alias for waf_analyze).", "inputSchema": { "type": "object", "properties": { "file": { "type": "string", "description": "Single file path to analyze.", }, "files": { "type": "array", "items": {"type": "string"}, "description": "List of file paths or glob patterns to analyze.", }, "limit": { "type": "integer", "default": 3, "description": "Max insights per file.", }, "severity_threshold": { "type": "string", "enum": ["info", "warning", "error"], "default": "warning", "description": "Minimum severity to include (alias for min_severity).", }, }, }, }, { "name": "waf_analyze", "description": "Analyze Terraform WAF file(s) and return curated insights (requires file or files).", "inputSchema": { "type": "object", "properties": { "file": { "type": "string", "description": "Single file path to analyze.", }, "files": { "type": "array", "items": {"type": "string"}, "description": "List of file paths or glob patterns to analyze.", }, "limit": { "type": "integer", "default": 3, "description": "Max insights per file.", }, "min_severity": { "type": "string", "enum": ["info", "warning", "error"], "default": "warning", "description": "Minimum severity to include.", }, }, }, }, { "name": "waf_assess", "description": "Run a broader assessment (optionally includes threat intel collection).", "inputSchema": { "type": "object", "properties": { "waf_config_path": { "type": "string", "description": "Path to Terraform WAF config (default: terraform/waf.tf).", }, "include_threat_intel": { "type": "boolean", "default": False, "description": "If true, attempt to collect threat intel (may require network and credentials).", }, }, }, }, { "name": "waf_generate_gitops_proposals", "description": "Generate GitOps-ready rule proposals (best-effort; requires threat intel to produce output).", "inputSchema": { "type": "object", "properties": { "waf_config_path": { "type": "string", "description": "Path to Terraform WAF config (default: terraform/waf.tf).", }, "include_threat_intel": { "type": "boolean", "default": True, "description": "Attempt to collect threat intel before proposing rules.", }, "max_proposals": { "type": "integer", "default": 5, "description": "Maximum proposals to generate.", }, }, }, }, ] class WafIntelligenceTools: def __init__(self) -> None: self.workspace_root = _cloudflare_root() self.repo_root = self.workspace_root.parent self.waf = WAFIntelligence(workspace_path=str(self.workspace_root)) def _resolve_path(self, raw: str) -> Path: path = Path(raw) if path.is_absolute(): return path candidates = [ Path.cwd() / path, self.workspace_root / path, self.repo_root / path, ] for candidate in candidates: if candidate.exists(): return candidate return self.workspace_root / path def waf_capabilities(self) -> Dict[str, Any]: return { "ok": True, "summary": "WAF Intelligence capabilities.", "data": {"capabilities": self.waf.capabilities}, "truncated": False, "next_steps": [ "Call waf_analyze(file=..., limit=...) to analyze config.", "Call waf_assess(include_threat_intel=true) for a broader assessment.", ], } def waf_analyze( self, *, file: Optional[str] = None, files: Optional[List[str]] = None, limit: int = 3, min_severity: str = "warning", ) -> Dict[str, Any]: paths: List[str] = [] if files: for pattern in files: paths.extend(glob.glob(pattern)) if file: paths.append(file) seen = set() unique_paths: List[str] = [] for p in paths: if p not in seen: seen.add(p) unique_paths.append(p) if not unique_paths: return { "ok": False, "summary": "Provide 'file' or 'files' to analyze.", "truncated": False, "next_steps": ["Call waf_analyze(file='terraform/waf.tf')"], } results: List[Dict[str, Any]] = [] for p in unique_paths: path = self._resolve_path(p) if not path.exists(): results.append( { "file": str(path), "ok": False, "summary": "File not found.", } ) continue insights = self.waf.analyze_and_recommend( str(path), limit=limit, min_severity=min_severity, ) results.append( { "file": str(path), "ok": True, "insights": [_insight_to_dict(i) for i in insights], } ) ok = all(r.get("ok") for r in results) return { "ok": ok, "summary": f"Analyzed {len(results)} file(s).", "data": {"results": results}, "truncated": False, "next_steps": [ "Raise/lower min_severity or limit to tune output size.", ], } def waf_assess( self, *, waf_config_path: Optional[str] = None, include_threat_intel: bool = False, ) -> Dict[str, Any]: waf_config_path_resolved = ( str(self._resolve_path(waf_config_path)) if waf_config_path else None ) assessment = self.waf.full_assessment( waf_config_path=waf_config_path_resolved, include_threat_intel=include_threat_intel, ) return { "ok": True, "summary": "WAF assessment complete.", "data": _assessment_to_dict(assessment), "truncated": False, "next_steps": [ "Call waf_generate_gitops_proposals(...) to draft Terraform rule proposals (best-effort).", ], } def waf_generate_gitops_proposals( self, *, waf_config_path: Optional[str] = None, include_threat_intel: bool = True, max_proposals: int = 5, ) -> Dict[str, Any]: waf_config_path_resolved = ( str(self._resolve_path(waf_config_path)) if waf_config_path else None ) assessment = self.waf.full_assessment( waf_config_path=waf_config_path_resolved, include_threat_intel=include_threat_intel, ) proposals = self.waf.generate_gitops_proposals( threat_report=assessment.threat_report, max_proposals=max_proposals, ) return { "ok": True, "summary": f"Generated {len(proposals)} proposal(s).", "data": { "assessment": _assessment_to_dict(assessment), "proposals": proposals, }, "truncated": False, "next_steps": [ "If proposals are empty, enable threat intel and ensure required credentials/log sources exist.", ], } class StdioJsonRpc: def __init__(self) -> None: self._in = sys.stdin.buffer self._out = sys.stdout.buffer self._mode: str | None = None # "headers" | "line" def read_message(self) -> Optional[Dict[str, Any]]: while True: if self._mode == "line": line = self._in.readline() if not line: return None raw = line.decode("utf-8", "replace").strip() if not raw: continue try: msg = json.loads(raw) except Exception: continue if isinstance(msg, dict): return msg continue first = self._in.readline() if not first: return None if first in (b"\r\n", b"\n"): continue # Auto-detect newline-delimited JSON framing. if self._mode is None and first.lstrip().startswith(b"{"): try: msg = json.loads(first.decode("utf-8", "replace")) except Exception: msg = None if isinstance(msg, dict): self._mode = "line" return msg headers: Dict[str, str] = {} try: text = first.decode("utf-8", "replace").strip() except Exception: continue if ":" not in text: continue k, v = text.split(":", 1) headers[k.lower().strip()] = v.strip() while True: line = self._in.readline() if not line: return None if line in (b"\r\n", b"\n"): break try: text = line.decode("utf-8", "replace").strip() except Exception: continue if ":" not in text: continue k, v = text.split(":", 1) headers[k.lower().strip()] = v.strip() if "content-length" not in headers: return None try: length = int(headers["content-length"]) except ValueError: return None body = self._in.read(length) if not body: return None self._mode = "headers" msg = json.loads(body.decode("utf-8", "replace")) if isinstance(msg, dict): return msg return None def write_message(self, message: Dict[str, Any]) -> None: if self._mode == "line": payload = json.dumps( message, ensure_ascii=False, separators=(",", ":"), default=str ).encode("utf-8") self._out.write(payload + b"\n") self._out.flush() return body = json.dumps( message, ensure_ascii=False, separators=(",", ":"), default=str ).encode("utf-8") header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8") self._out.write(header) self._out.write(body) self._out.flush() def main() -> None: tools = WafIntelligenceTools() rpc = StdioJsonRpc() handlers: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] = { "waf_capabilities": lambda a: tools.waf_capabilities(), "analyze_waf": lambda a: tools.waf_analyze( file=a.get("file"), files=a.get("files"), limit=int(a.get("limit", 3)), min_severity=str(a.get("severity_threshold", "warning")), ), "waf_analyze": lambda a: tools.waf_analyze(**a), "waf_assess": lambda a: tools.waf_assess(**a), "waf_generate_gitops_proposals": lambda a: tools.waf_generate_gitops_proposals( **a ), } while True: msg = rpc.read_message() if msg is None: return method = msg.get("method") msg_id = msg.get("id") params = msg.get("params") or {} try: if method == "initialize": result = { "protocolVersion": "2024-11-05", "serverInfo": {"name": "waf_intelligence", "version": "0.1.0"}, "capabilities": {"tools": {}}, } rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result}) continue if method == "tools/list": rpc.write_message( {"jsonrpc": "2.0", "id": msg_id, "result": {"tools": TOOLS}} ) continue if method == "tools/call": tool_name = str(params.get("name") or "") args = params.get("arguments") or {} routing_action, shadow = layer0_entry( _shadow_query_repr(tool_name, args) ) if routing_action != "HANDOFF_TO_LAYER1": rpc.write_message( { "jsonrpc": "2.0", "id": msg_id, "result": _mcp_text_result( _layer0_payload(routing_action, shadow), is_error=True ), } ) continue handler = handlers.get(tool_name) if not handler: rpc.write_message( { "jsonrpc": "2.0", "id": msg_id, "result": _mcp_text_result( { "ok": False, "summary": f"Unknown tool: {tool_name}", "data": {"known_tools": sorted(handlers.keys())}, "truncated": False, "next_steps": ["Call tools/list"], }, is_error=True, ), } ) continue payload = handler(args) is_error = ( not bool(payload.get("ok", True)) if isinstance(payload, dict) else False ) rpc.write_message( { "jsonrpc": "2.0", "id": msg_id, "result": _mcp_text_result(payload, is_error=is_error), } ) continue # Ignore notifications. if msg_id is None: continue rpc.write_message( { "jsonrpc": "2.0", "id": msg_id, "result": _mcp_text_result( {"ok": False, "summary": f"Unsupported method: {method}"}, is_error=True, ), } ) except Exception as e: # noqa: BLE001 if msg_id is not None: rpc.write_message( { "jsonrpc": "2.0", "id": msg_id, "result": _mcp_text_result( {"ok": False, "summary": f"fatal error: {e}"}, is_error=True, ), } ) def _shadow_query_repr(tool_name: str, tool_args: Dict[str, Any]) -> str: if tool_name == "waf_capabilities": return "List WAF Intelligence capabilities." try: return f"{tool_name}: {json.dumps(tool_args, sort_keys=True, default=str)}" except Exception: return f"{tool_name}: {str(tool_args)}" def _layer0_payload(routing_action: str, shadow: ShadowEvalResult) -> Dict[str, Any]: if routing_action == "FAIL_CLOSED": return {"ok": False, "summary": "Layer 0: cannot comply with this request."} if routing_action == "HANDOFF_TO_GUARDRAILS": reason = shadow.reason or "governance_violation" return { "ok": False, "summary": f"Layer 0: governance violation detected ({reason}).", } if routing_action == "PROMPT_FOR_CLARIFICATION": return { "ok": False, "summary": "Layer 0: request is ambiguous. Please clarify and retry.", } return {"ok": False, "summary": "Layer 0: unrecognized routing action; refusing."} if __name__ == "__main__": main()