chore: pre-migration snapshot

Layer0, MCP servers, Terraform consolidation
Vault Sovereign
2025-12-27 01:52:27 +00:00
parent 7f2e60e1c5
commit f0b8d962de
67 changed files with 14887 additions and 650 deletions


@@ -3,4 +3,6 @@ MCP tools for the CLOUDFLARE workspace.
Currently:
- oracle_answer: compliance / security oracle
- cloudflare_safe: summary-first Cloudflare state + tunnel helpers
- akash_docs: Akash docs fetch/search + SDL template helper
"""


@@ -0,0 +1,10 @@
"""
Akash docs + deployment helpers exposed as an MCP server.
Tools:
- akash_docs_list_routes: discover common docs routes from akash.network
- akash_docs_fetch: fetch a docs page (prefers GitHub markdown, falls back to site HTML)
- akash_docs_search: keyword search across discovered routes (cached)
- akash_sdl_snippet: generate a minimal Akash SDL template
"""


@@ -0,0 +1,7 @@
from __future__ import annotations
from .server import main
if __name__ == "__main__":
main()

mcp/akash_docs/server.py

@@ -0,0 +1,861 @@
from __future__ import annotations
import hashlib
import json
import os
import re
import sys
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
AKASH_SITE_BASE = "https://akash.network"
AKASH_DOCS_BASE = f"{AKASH_SITE_BASE}/docs"
AKASH_DOCS_GITHUB_OWNER = "akash-network"
AKASH_DOCS_GITHUB_REPO = "website-revamp"
AKASH_DOCS_GITHUB_REF_DEFAULT = "main"
AKASH_DOCS_GITHUB_DOCS_ROOT = "src/content/Docs"
MAX_BYTES_DEFAULT = 32_000
def _repo_root() -> Path:
# server.py -> akash_docs -> mcp -> cloudflare -> <repo root>
return Path(__file__).resolve().parents[3]
def _utc_now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _max_bytes() -> int:
raw = (os.getenv("VM_MCP_MAX_BYTES") or "").strip()
if not raw:
return MAX_BYTES_DEFAULT
try:
return max(4_096, int(raw))
except ValueError:
return MAX_BYTES_DEFAULT
def _sha256_hex(text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()
def _http_get(url: str, *, timeout: int = 30) -> str:
req = urllib.request.Request(
url=url,
headers={
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
"User-Agent": "work-core-akamcp/0.1 (+https://akash.network)",
},
)
with urllib.request.urlopen(req, timeout=timeout) as resp:
return resp.read().decode("utf-8", "replace")
def _normalize_route(route_or_url: str) -> Tuple[str, str]:
"""
Returns (route, canonical_url).
route: "getting-started/what-is-akash" (no leading/trailing slashes)
canonical_url: https://akash.network/docs/<route>
"""
raw = (route_or_url or "").strip()
if not raw:
return "", AKASH_DOCS_BASE + "/"
if raw.startswith("http://") or raw.startswith("https://"):
parsed = urllib.parse.urlparse(raw)
path = parsed.path or ""
# Normalize to docs route if possible.
if path in ("/docs", "/docs/"):
return "", AKASH_DOCS_BASE + "/"
if path.startswith("/docs/"):
route = path[len("/docs/") :].strip("/")
return route, f"{AKASH_DOCS_BASE}/{route}"
return path.strip("/"), raw
# Accept "/docs/..." or "docs/..."
route = raw.lstrip("/")
if route in ("docs", "docs/"):
return "", AKASH_DOCS_BASE + "/"
if route.startswith("docs/"):
route = route[len("docs/") :]
route = route.strip("/")
return route, f"{AKASH_DOCS_BASE}/{route}" if route else AKASH_DOCS_BASE + "/"
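# Worked examples of the normalization above (routes are illustrative):
#   "https://akash.network/docs/getting-started/what-is-akash"
#       -> ("getting-started/what-is-akash", "https://akash.network/docs/getting-started/what-is-akash")
#   "/docs/providers" -> ("providers", "https://akash.network/docs/providers")
#   "docs"            -> ("", "https://akash.network/docs/")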
def _strip_frontmatter(markdown: str) -> str:
# Remove leading YAML frontmatter: ---\n...\n---\n
if not markdown.startswith("---"):
return markdown
m = re.match(r"^---\s*\n.*?\n---\s*\n", markdown, flags=re.S)
if not m:
return markdown
return markdown[m.end() :]
def _github_candidates(route: str) -> List[str]:
base = f"{AKASH_DOCS_GITHUB_DOCS_ROOT}/{route}".rstrip("/")
candidates = [
f"{base}/index.md",
f"{base}/index.mdx",
f"{base}.md",
f"{base}.mdx",
]
# Handle root docs landing if route is empty.
if not route:
candidates = [
f"{AKASH_DOCS_GITHUB_DOCS_ROOT}/index.md",
f"{AKASH_DOCS_GITHUB_DOCS_ROOT}/index.mdx",
]
return candidates
def _fetch_markdown_from_github(route: str, *, ref: str) -> Tuple[str, str, str]:
"""
Returns (markdown, raw_url, repo_path) or raises urllib.error.HTTPError.
"""
last_err: Optional[urllib.error.HTTPError] = None
for repo_path in _github_candidates(route):
raw_url = (
f"https://raw.githubusercontent.com/{AKASH_DOCS_GITHUB_OWNER}/"
f"{AKASH_DOCS_GITHUB_REPO}/{ref}/{repo_path}"
)
try:
return _http_get(raw_url), raw_url, repo_path
except urllib.error.HTTPError as e:
if e.code == 404:
last_err = e
continue
raise
if last_err:
raise last_err
raise urllib.error.HTTPError(
url="",
code=404,
msg="Not Found",
hdrs=None,
fp=None,
)
def _extract_article_html(page_html: str) -> str:
m = re.search(r"<article\b[^>]*>(.*?)</article>", page_html, flags=re.S | re.I)
if m:
return m.group(1)
m = re.search(r"<main\b[^>]*>(.*?)</main>", page_html, flags=re.S | re.I)
if m:
return m.group(1)
return page_html
def _html_to_text(article_html: str) -> str:
# Drop scripts/styles
cleaned = re.sub(
r"<(script|style)\b[^>]*>.*?</\1>", "", article_html, flags=re.S | re.I
)
# Preserve code blocks a bit better (Astro uses <div class="ec-line"> for each line)
def _pre_repl(match: re.Match[str]) -> str:
pre = match.group(0)
pre = re.sub(r"</div>\s*", "\n", pre, flags=re.I)
pre = re.sub(r"<div\b[^>]*>", "", pre, flags=re.I)
pre = re.sub(r"<br\s*/?>", "\n", pre, flags=re.I)
pre = re.sub(r"<[^>]+>", "", pre)
return "\n```\n" + _html_unescape(pre).strip() + "\n```\n"
cleaned = re.sub(r"<pre\b[^>]*>.*?</pre>", _pre_repl, cleaned, flags=re.S | re.I)
# Newlines for common block tags
cleaned = re.sub(
r"</(p|h1|h2|h3|h4|h5|h6|li|blockquote)>", "\n", cleaned, flags=re.I
)
cleaned = re.sub(r"<br\s*/?>", "\n", cleaned, flags=re.I)
cleaned = re.sub(r"<hr\b[^>]*>", "\n---\n", cleaned, flags=re.I)
# Strip remaining tags
cleaned = re.sub(r"<[^>]+>", "", cleaned)
text = _html_unescape(cleaned)
lines = [ln.rstrip() for ln in text.splitlines()]
# Collapse excessive blank lines
out: List[str] = []
blank = False
for ln in lines:
if ln.strip() == "":
if blank:
continue
blank = True
out.append("")
continue
blank = False
out.append(ln.strip())
return "\n".join(out).strip()
def _html_unescape(text: str) -> str:
# Local import keeps the module-level import list small; repeated calls hit the sys.modules cache.
import html as _html
return _html.unescape(text)
def _discover_routes_from_docs_index() -> List[str]:
html = _http_get(AKASH_DOCS_BASE + "/")
hrefs = set(re.findall(r'href=\"(/docs/[^\"#?]+)\"', html))
routes: List[str] = []
for href in sorted(hrefs):
route, _url = _normalize_route(href)
if route:
routes.append(route)
return routes
@dataclass(frozen=True)
class CachedDoc:
cache_key: str
fetched_at: str
source: str
route: str
url: str
ref: str
content_path: str
class DocStore:
def __init__(self, root_dir: Path) -> None:
self.root_dir = root_dir
self.pages_dir = root_dir / "pages"
self.index_path = root_dir / "index.json"
self.pages_dir.mkdir(parents=True, exist_ok=True)
self._index: Dict[str, Dict[str, Any]] = {}
if self.index_path.exists():
try:
self._index = json.loads(self.index_path.read_text(encoding="utf-8"))
except Exception:
self._index = {}
def _write_index(self) -> None:
tmp = self.index_path.with_suffix(".tmp")
tmp.write_text(
json.dumps(self._index, ensure_ascii=False, indent=2) + "\n",
encoding="utf-8",
)
tmp.replace(self.index_path)
def get(self, cache_key: str) -> Optional[CachedDoc]:
raw = self._index.get(cache_key)
if not raw:
return None
path = Path(raw.get("content_path") or "")
if not path.exists():
return None
return CachedDoc(
cache_key=cache_key,
fetched_at=str(raw.get("fetched_at") or ""),
source=str(raw.get("source") or ""),
route=str(raw.get("route") or ""),
url=str(raw.get("url") or ""),
ref=str(raw.get("ref") or ""),
content_path=str(path),
)
def save(
self,
*,
cache_key: str,
source: str,
route: str,
url: str,
ref: str,
content: str,
) -> CachedDoc:
content_hash = _sha256_hex(f"{source}:{ref}:{url}")[:20]
path = self.pages_dir / f"{content_hash}.txt"
path.write_text(content, encoding="utf-8")
entry = {
"fetched_at": _utc_now_iso(),
"source": source,
"route": route,
"url": url,
"ref": ref,
"content_path": str(path),
}
self._index[cache_key] = entry
self._write_index()
return self.get(cache_key) or CachedDoc(
cache_key=cache_key,
fetched_at=entry["fetched_at"],
source=source,
route=route,
url=url,
ref=ref,
content_path=str(path),
)
def _default_state_dir() -> Path:
return _repo_root() / "archive_runtime" / "akash_docs_mcp"
def _truncate_to_max_bytes(text: str, *, max_bytes: int) -> Tuple[str, bool]:
blob = text.encode("utf-8")
if len(blob) <= max_bytes:
return text, False
# Reserve a bit for the truncation notice
reserve = min(512, max_bytes // 10)
head = blob[: max(0, max_bytes - reserve)].decode("utf-8", "replace")
head = head.rstrip() + "\n\n[TRUNCATED: response exceeded VM_MCP_MAX_BYTES]\n"
return head, True
def _mcp_text_result(text: str, *, is_error: bool = False) -> Dict[str, Any]:
text, _truncated = _truncate_to_max_bytes(text, max_bytes=_max_bytes())
result: Dict[str, Any] = {"content": [{"type": "text", "text": text}]}
if is_error:
result["isError"] = True
return result
class AkashDocsTools:
def __init__(self) -> None:
state_dir = Path(os.getenv("VM_AKASH_DOCS_MCP_STATE_DIR") or _default_state_dir())
self.store = DocStore(state_dir)
def akash_docs_list_routes(self) -> Dict[str, Any]:
routes = _discover_routes_from_docs_index()
return {
"ok": True,
"summary": f"Discovered {len(routes)} docs route(s) from {AKASH_DOCS_BASE}/.",
"data": {"routes": routes},
"next_steps": ["akash_docs_fetch(route_or_url=...)"],
}
def akash_docs_fetch(
self,
*,
route_or_url: str,
source: str = "auto",
ref: str = AKASH_DOCS_GITHUB_REF_DEFAULT,
max_chars: int = 12_000,
refresh: bool = False,
strip_frontmatter: bool = True,
) -> Dict[str, Any]:
route, canonical_url = _normalize_route(route_or_url)
source_norm = (source or "auto").strip().lower()
if source_norm not in ("auto", "github", "site"):
raise ValueError("source must be one of: auto, github, site")
max_chars_int = max(0, int(max_chars))
# Avoid flooding clients; open content_path for full content.
max_chars_int = min(max_chars_int, max(2_000, _max_bytes() - 8_000))
cache_key = f"{source_norm}:{ref}:{route or canonical_url}"
cached = self.store.get(cache_key)
if cached and not refresh:
content = Path(cached.content_path).read_text(encoding="utf-8")
if strip_frontmatter and cached.source == "github":
content = _strip_frontmatter(content)
truncated = len(content) > max_chars_int
return {
"ok": True,
"summary": "Returned cached docs content.",
"data": {
"source": cached.source,
"route": cached.route,
"url": cached.url,
"ref": cached.ref,
"cached": True,
"fetched_at": cached.fetched_at,
"content": content[:max_chars_int],
"truncated": truncated,
"content_path": cached.content_path,
},
"next_steps": ["Set refresh=true to refetch."],
}
attempted: List[Dict[str, Any]] = []
def _try_github() -> Optional[Tuple[str, str, str]]:
try:
md, raw_url, repo_path = _fetch_markdown_from_github(route, ref=ref)
return md, raw_url, repo_path
except urllib.error.HTTPError as e:
attempted.append({"source": "github", "status": getattr(e, "code", None), "detail": str(e)})
return None
def _try_site() -> Optional[Tuple[str, str]]:
try:
html = _http_get(canonical_url)
article = _extract_article_html(html)
text = _html_to_text(article)
return text, canonical_url
except urllib.error.HTTPError as e:
attempted.append({"source": "site", "status": getattr(e, "code", None), "detail": str(e)})
return None
content: str
final_source: str
final_url: str
extra: Dict[str, Any] = {}
if source_norm in ("auto", "github"):
gh = _try_github()
if gh:
content, final_url, repo_path = gh
final_source = "github"
extra["repo_path"] = repo_path
elif source_norm == "github":
raise ValueError("GitHub fetch failed; try source='site' or verify the route/ref.")
else:
site = _try_site()
if not site:
raise ValueError(f"Fetch failed for route_or_url={route_or_url!r}. Attempts: {attempted}")
content, final_url = site
final_source = "site"
else:
site = _try_site()
if not site:
raise ValueError(f"Site fetch failed for route_or_url={route_or_url!r}. Attempts: {attempted}")
content, final_url = site
final_source = "site"
cached_doc = self.store.save(
cache_key=cache_key,
source=final_source,
route=route,
url=final_url,
ref=ref,
content=content,
)
content_view = content
if strip_frontmatter and final_source == "github":
content_view = _strip_frontmatter(content_view)
truncated = len(content_view) > max_chars_int
content_out = content_view[:max_chars_int]
return {
"ok": True,
"summary": f"Fetched docs via {final_source}.",
"data": {
"source": final_source,
"route": route,
"url": final_url,
"ref": ref,
"cached": False,
"fetched_at": cached_doc.fetched_at,
"content": content_out,
"truncated": truncated,
"content_path": cached_doc.content_path,
"attempts": attempted,
**extra,
},
"next_steps": [
"akash_docs_search(query=..., refresh=false)",
],
}
def akash_docs_search(
self,
*,
query: str,
limit: int = 10,
refresh: bool = False,
ref: str = AKASH_DOCS_GITHUB_REF_DEFAULT,
) -> Dict[str, Any]:
q = (query or "").strip()
if not q:
raise ValueError("query is required")
limit = max(1, min(50, int(limit)))
routes = _discover_routes_from_docs_index()
hits: List[Dict[str, Any]] = []
for route in routes:
try:
doc = self.akash_docs_fetch(
route_or_url=route,
source="github",
ref=ref,
max_chars=0, # search reads full content from content_path
refresh=refresh,
strip_frontmatter=True,
)
except ValueError:
# Routes without fetchable GitHub markdown are skipped instead of aborting the whole search.
continue
data = doc.get("data") or {}
content_path = data.get("content_path")
if not content_path:
continue
try:
content = Path(str(content_path)).read_text(encoding="utf-8")
content = _strip_frontmatter(content)
except Exception:
continue
idx = content.lower().find(q.lower())
if idx == -1:
continue
start = max(0, idx - 80)
end = min(len(content), idx + 160)
snippet = content[start:end].replace("\n", " ").strip()
hits.append(
{
"route": route,
"url": data.get("url"),
"source": data.get("source"),
"snippet": snippet,
}
)
if len(hits) >= limit:
break
return {
"ok": True,
"summary": f"Found {len(hits)} hit(s) across {len(routes)} route(s).",
"data": {"query": q, "hits": hits, "routes_searched": len(routes)},
"next_steps": ["akash_docs_fetch(route_or_url=hits[0].route)"],
}
def akash_sdl_snippet(
self,
*,
service_name: str,
container_image: str,
port: int,
cpu_units: float = 0.5,
memory_size: str = "512Mi",
storage_size: str = "512Mi",
denom: str = "uakt",
price_amount: int = 100,
) -> Dict[str, Any]:
svc = (service_name or "").strip()
img = (container_image or "").strip()
if not svc:
raise ValueError("service_name is required")
if not img:
raise ValueError("container_image is required")
port_int = int(port)
if port_int <= 0 or port_int > 65535:
raise ValueError("port must be 1..65535")
sdl = f"""version: \"2.0\"
services:
{svc}:
image: {img}
expose:
- port: {port_int}
to:
- global: true
profiles:
compute:
{svc}:
resources:
cpu:
units: {cpu_units}
memory:
size: {memory_size}
storage:
size: {storage_size}
placement:
akash:
pricing:
{svc}:
denom: {denom}
amount: {int(price_amount)}
deployment:
{svc}:
akash:
profile: {svc}
count: 1
"""
return {
"ok": True,
"summary": "Generated an Akash SDL template.",
"data": {
"service_name": svc,
"container_image": img,
"port": port_int,
"sdl": sdl,
},
"next_steps": [
"Save as deploy.yaml and deploy via Akash Console or akash CLI.",
],
}
TOOLS: List[Dict[str, Any]] = [
{
"name": "akash_docs_list_routes",
"description": "Discover common Akash docs routes by scraping https://akash.network/docs/ (SSR HTML).",
"inputSchema": {"type": "object", "properties": {}},
},
{
"name": "akash_docs_fetch",
"description": "Fetch an Akash docs page (prefers GitHub markdown in akash-network/website-revamp; falls back to site HTML).",
"inputSchema": {
"type": "object",
"properties": {
"route_or_url": {"type": "string"},
"source": {
"type": "string",
"description": "auto|github|site",
"default": "auto",
},
"ref": {"type": "string", "default": AKASH_DOCS_GITHUB_REF_DEFAULT},
"max_chars": {"type": "integer", "default": 12000},
"refresh": {"type": "boolean", "default": False},
"strip_frontmatter": {"type": "boolean", "default": True},
},
"required": ["route_or_url"],
},
},
{
"name": "akash_docs_search",
"description": "Keyword search across routes discovered from /docs (fetches + caches GitHub markdown).",
"inputSchema": {
"type": "object",
"properties": {
"query": {"type": "string"},
"limit": {"type": "integer", "default": 10},
"refresh": {"type": "boolean", "default": False},
"ref": {"type": "string", "default": AKASH_DOCS_GITHUB_REF_DEFAULT},
},
"required": ["query"],
},
},
{
"name": "akash_sdl_snippet",
"description": "Generate a minimal Akash SDL manifest for a single service exposing one port.",
"inputSchema": {
"type": "object",
"properties": {
"service_name": {"type": "string"},
"container_image": {"type": "string"},
"port": {"type": "integer"},
"cpu_units": {"type": "number", "default": 0.5},
"memory_size": {"type": "string", "default": "512Mi"},
"storage_size": {"type": "string", "default": "512Mi"},
"denom": {"type": "string", "default": "uakt"},
"price_amount": {"type": "integer", "default": 100},
},
"required": ["service_name", "container_image", "port"],
},
},
]
class StdioJsonRpc:
def __init__(self) -> None:
self._in = sys.stdin.buffer
self._out = sys.stdout.buffer
self._mode: str | None = None # "headers" | "line"
def read_message(self) -> Optional[Dict[str, Any]]:
while True:
if self._mode == "line":
line = self._in.readline()
if not line:
return None
raw = line.decode("utf-8", "replace").strip()
if not raw:
continue
try:
msg = json.loads(raw)
except Exception:
continue
if isinstance(msg, dict):
return msg
continue
first = self._in.readline()
if not first:
return None
if first in (b"\r\n", b"\n"):
continue
# Auto-detect newline-delimited JSON framing.
if self._mode is None and first.lstrip().startswith(b"{"):
try:
msg = json.loads(first.decode("utf-8", "replace"))
except Exception:
msg = None
if isinstance(msg, dict):
self._mode = "line"
return msg
headers: Dict[str, str] = {}
try:
text = first.decode("utf-8", "replace").strip()
except Exception:
continue
if ":" not in text:
continue
k, v = text.split(":", 1)
headers[k.lower().strip()] = v.strip()
while True:
line = self._in.readline()
if not line:
return None
if line in (b"\r\n", b"\n"):
break
try:
text = line.decode("utf-8", "replace").strip()
except Exception:
continue
if ":" not in text:
continue
k, v = text.split(":", 1)
headers[k.lower().strip()] = v.strip()
if "content-length" not in headers:
return None
try:
length = int(headers["content-length"])
except ValueError:
return None
body = self._in.read(length)
if not body:
return None
self._mode = "headers"
msg = json.loads(body.decode("utf-8", "replace"))
if isinstance(msg, dict):
return msg
return None
def write_message(self, message: Dict[str, Any]) -> None:
if self._mode == "line":
payload = json.dumps(
message, ensure_ascii=False, separators=(",", ":"), default=str
).encode("utf-8")
self._out.write(payload + b"\n")
self._out.flush()
return
body = json.dumps(message, ensure_ascii=False, separators=(",", ":")).encode(
"utf-8"
)
header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
self._out.write(header)
self._out.write(body)
self._out.flush()
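# Framing note: read_message() detects the wire format from the first request and
# write_message() mirrors it. Two equivalent ways to send the same initialize call:
#   newline-delimited JSON (one message per line):
#     {"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}
#   LSP-style header framing (Content-Length header, blank line, JSON body):
#     Content-Length: <byte length of body>
#
#     {"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}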
def main() -> None:
tools = AkashDocsTools()
rpc = StdioJsonRpc()
handlers: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] = {
"akash_docs_list_routes": lambda a: tools.akash_docs_list_routes(),
"akash_docs_fetch": lambda a: tools.akash_docs_fetch(**a),
"akash_docs_search": lambda a: tools.akash_docs_search(**a),
"akash_sdl_snippet": lambda a: tools.akash_sdl_snippet(**a),
}
while True:
msg = rpc.read_message()
if msg is None:
return
method = msg.get("method")
msg_id = msg.get("id")
params = msg.get("params") or {}
try:
if method == "initialize":
result = {
"protocolVersion": "2024-11-05",
"serverInfo": {"name": "akash_docs", "version": "0.1.0"},
"capabilities": {"tools": {}},
}
rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result})
continue
if method == "tools/list":
rpc.write_message(
{"jsonrpc": "2.0", "id": msg_id, "result": {"tools": TOOLS}}
)
continue
if method == "tools/call":
tool_name = str(params.get("name") or "")
args = params.get("arguments") or {}
if tool_name not in handlers:
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
f"Unknown tool: {tool_name}\nKnown tools: {', '.join(sorted(handlers.keys()))}",
is_error=True,
),
}
)
continue
try:
payload = handlers[tool_name](args)
# Split payload: meta JSON + optional raw content.
# If payload["data"]["content"] exists, emit it as a second text block for readability.
data = payload.get("data") if isinstance(payload, dict) else None
content_text = None
if isinstance(data, dict) and isinstance(data.get("content"), str):
content_text = data["content"]
data = dict(data)
data.pop("content", None)
payload = dict(payload)
payload["data"] = data
blocks = [json.dumps(payload, ensure_ascii=False, indent=2)]
if content_text:
blocks.append(content_text)
result: Dict[str, Any] = {
"content": [{"type": "text", "text": b} for b in blocks]
}
rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result})
except Exception as e: # noqa: BLE001
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
f"Error: {e}",
is_error=True,
),
}
)
continue
# Ignore notifications.
if msg_id is None:
continue
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
f"Unsupported method: {method}",
is_error=True,
),
}
)
except Exception as e: # noqa: BLE001
# Last-resort: avoid crashing the server.
if msg_id is not None:
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(f"fatal error: {e}", is_error=True),
}
)


@@ -0,0 +1,11 @@
"""
cloudflare_safe MCP server.
Summary-first Cloudflare tooling with hard output caps and default redaction.
"""
from __future__ import annotations
__all__ = ["__version__"]
__version__ = "0.1.0"
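Every tool in this server returns the same summary-first envelope; an illustrative payload, before _safe_json applies redaction and the VM_MCP_MAX_BYTES cap, looks like:

example_envelope = {
    "ok": True,
    "summary": "Snapshot cf_... captured (scopes=access_apps,tunnels).",
    "data": {"snapshot_id": "cf_...", "snapshot_path": "<state dir>/snapshots/cf_....json"},
    "truncated": False,
    "next_steps": ["cf_config_diff(from_snapshot_id=..., to_snapshot_id=...)"],
}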


@@ -0,0 +1,6 @@
from __future__ import annotations
from .server import main
if __name__ == "__main__":
main()


@@ -0,0 +1,496 @@
from __future__ import annotations
import hashlib
import json
import os
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import (
Any,
Dict,
Iterable,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
)
CF_API_BASE = "https://api.cloudflare.com/client/v4"
def utc_now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def stable_hash(data: Any) -> str:
blob = json.dumps(
data, sort_keys=True, separators=(",", ":"), ensure_ascii=False
).encode("utf-8")
return hashlib.sha256(blob).hexdigest()
class CloudflareError(RuntimeError):
pass
@dataclass(frozen=True)
class CloudflareContext:
api_token: str
account_id: str
@staticmethod
def from_env() -> "CloudflareContext":
api_token = (
os.getenv("CLOUDFLARE_API_TOKEN")
or os.getenv("CF_API_TOKEN")
or os.getenv("CLOUDFLARE_TOKEN")
or ""
).strip()
account_id = (
os.getenv("CLOUDFLARE_ACCOUNT_ID") or os.getenv("CF_ACCOUNT_ID") or ""
).strip()
if not api_token:
raise CloudflareError(
"Missing Cloudflare API token. Set CLOUDFLARE_API_TOKEN (or CF_API_TOKEN)."
)
if not account_id:
raise CloudflareError(
"Missing Cloudflare account id. Set CLOUDFLARE_ACCOUNT_ID (or CF_ACCOUNT_ID)."
)
return CloudflareContext(api_token=api_token, account_id=account_id)
class CloudflareClient:
def __init__(self, *, api_token: str) -> None:
self.api_token = api_token
def _request(
self,
method: str,
path: str,
*,
params: Optional[Mapping[str, str]] = None,
) -> Dict[str, Any]:
url = f"{CF_API_BASE}{path}"
if params:
url = f"{url}?{urllib.parse.urlencode(params)}"
req = urllib.request.Request(
url=url,
method=method,
headers={
"Authorization": f"Bearer {self.api_token}",
"Accept": "application/json",
"Content-Type": "application/json",
},
)
try:
with urllib.request.urlopen(req, timeout=30) as resp:
raw = resp.read()
except urllib.error.HTTPError as e:
raw = e.read() if hasattr(e, "read") else b""
detail = raw.decode("utf-8", "replace")
raise CloudflareError(
f"Cloudflare API HTTP {e.code} for {path}: {detail}"
) from e
except urllib.error.URLError as e:
raise CloudflareError(
f"Cloudflare API request failed for {path}: {e}"
) from e
try:
data = json.loads(raw.decode("utf-8", "replace"))
except json.JSONDecodeError:
raise CloudflareError(
f"Cloudflare API returned non-JSON for {path}: {raw[:200]!r}"
)
if not data.get("success", True):
raise CloudflareError(
f"Cloudflare API error for {path}: {data.get('errors')}"
)
return data
def paginate(
self,
path: str,
*,
params: Optional[Mapping[str, str]] = None,
per_page: int = 100,
max_pages: int = 5,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""
Fetch a paginated Cloudflare endpoint.
Returns (results, result_info).
"""
results: List[Dict[str, Any]] = []
page = 1
last_info: Dict[str, Any] = {}
while True:
merged_params: Dict[str, str] = {
"page": str(page),
"per_page": str(per_page),
}
if params:
merged_params.update({k: str(v) for k, v in params.items()})
data = self._request("GET", path, params=merged_params)
batch = data.get("result") or []
if not isinstance(batch, list):
batch = [batch]
results.extend(batch)
last_info = data.get("result_info") or {}
total_pages = int(last_info.get("total_pages") or 1)
if page >= total_pages or page >= max_pages:
break
page += 1
return results, last_info
def list_zones(self) -> List[Dict[str, Any]]:
zones, _info = self.paginate("/zones", max_pages=2)
return zones
def list_dns_records_summary(
self, zone_id: str, *, max_pages: int = 1
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
return self.paginate(f"/zones/{zone_id}/dns_records", max_pages=max_pages)
def list_tunnels(self, account_id: str) -> List[Dict[str, Any]]:
tunnels, _info = self.paginate(
f"/accounts/{account_id}/cfd_tunnel", max_pages=2
)
return tunnels
def list_tunnel_connections(
self, account_id: str, tunnel_id: str
) -> List[Dict[str, Any]]:
data = self._request(
"GET", f"/accounts/{account_id}/cfd_tunnel/{tunnel_id}/connections"
)
result = data.get("result") or []
return result if isinstance(result, list) else [result]
def list_access_apps(self, account_id: str) -> List[Dict[str, Any]]:
apps, _info = self.paginate(f"/accounts/{account_id}/access/apps", max_pages=3)
return apps
def list_access_policies(
self, account_id: str, app_id: str
) -> List[Dict[str, Any]]:
policies, _info = self.paginate(
f"/accounts/{account_id}/access/apps/{app_id}/policies",
max_pages=3,
)
return policies
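# Usage sketch (assumes CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID are exported):
#   ctx = CloudflareContext.from_env()
#   client = CloudflareClient(api_token=ctx.api_token)
#   zones = client.list_zones()                    # paginated GET /zones, capped at 2 pages
#   tunnels = client.list_tunnels(ctx.account_id)  # raw API objects; SnapshotStore trims them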
@dataclass(frozen=True)
class SnapshotMeta:
snapshot_id: str
created_at: str
scopes: List[str]
snapshot_path: str
class SnapshotStore:
def __init__(self, root_dir: Path) -> None:
self.root_dir = root_dir
self.snapshots_dir = root_dir / "snapshots"
self.diffs_dir = root_dir / "diffs"
self.snapshots_dir.mkdir(parents=True, exist_ok=True)
self.diffs_dir.mkdir(parents=True, exist_ok=True)
self._index: Dict[str, SnapshotMeta] = {}
def get(self, snapshot_id: str) -> SnapshotMeta:
if snapshot_id not in self._index:
raise CloudflareError(f"Unknown snapshot_id: {snapshot_id}")
return self._index[snapshot_id]
def load_snapshot(self, snapshot_id: str) -> Dict[str, Any]:
meta = self.get(snapshot_id)
return json.loads(Path(meta.snapshot_path).read_text(encoding="utf-8"))
def create_snapshot(
self,
*,
client: CloudflareClient,
ctx: CloudflareContext,
scopes: Sequence[str],
zone_id: Optional[str] = None,
zone_name: Optional[str] = None,
dns_max_pages: int = 1,
) -> Tuple[SnapshotMeta, Dict[str, Any]]:
scopes_norm = sorted(set(scopes))
created_at = utc_now_iso()
zones = client.list_zones()
zones_min = [
{
"id": z.get("id"),
"name": z.get("name"),
"status": z.get("status"),
"paused": z.get("paused"),
}
for z in zones
]
selected_zone_id = zone_id
if not selected_zone_id and zone_name:
for z in zones_min:
if z.get("name") == zone_name:
selected_zone_id = str(z.get("id"))
break
snapshot: Dict[str, Any] = {
"meta": {
"snapshot_id": "",
"created_at": created_at,
"account_id": ctx.account_id,
"scopes": scopes_norm,
},
"zones": zones_min,
}
if "tunnels" in scopes_norm:
tunnels = client.list_tunnels(ctx.account_id)
tunnels_min: List[Dict[str, Any]] = []
for t in tunnels:
tid = t.get("id")
name = t.get("name")
status = t.get("status")
connector_count: Optional[int] = None
last_seen: Optional[str] = None
if tid and status != "deleted":
conns = client.list_tunnel_connections(ctx.account_id, str(tid))
connector_count = len(conns)
# Pick the most recent 'opened_at' if present.
opened = [c.get("opened_at") for c in conns if isinstance(c, dict)]
opened = [o for o in opened if isinstance(o, str)]
last_seen = max(opened) if opened else None
tunnels_min.append(
{
"id": tid,
"name": name,
"status": status,
"created_at": t.get("created_at"),
"deleted_at": t.get("deleted_at"),
"connector_count": connector_count,
"last_seen": last_seen,
}
)
snapshot["tunnels"] = tunnels_min
if "access_apps" in scopes_norm:
apps = client.list_access_apps(ctx.account_id)
apps_min = [
{
"id": a.get("id"),
"name": a.get("name"),
"domain": a.get("domain"),
"type": a.get("type"),
"created_at": a.get("created_at"),
"updated_at": a.get("updated_at"),
}
for a in apps
]
snapshot["access_apps"] = apps_min
if "dns" in scopes_norm:
if selected_zone_id:
records, info = client.list_dns_records_summary(
selected_zone_id, max_pages=dns_max_pages
)
records_min = [
{
"id": r.get("id"),
"type": r.get("type"),
"name": r.get("name"),
"content": r.get("content"),
"proxied": r.get("proxied"),
"ttl": r.get("ttl"),
}
for r in records
]
snapshot["dns"] = {
"zone_id": selected_zone_id,
"zone_name": zone_name,
"result_info": info,
"records_sample": records_min,
}
else:
snapshot["dns"] = {
"note": "dns scope requested but no zone_id/zone_name provided; only zones list included",
}
snapshot_id = f"cf_{created_at.replace(':', '').replace('-', '').replace('.', '')}_{stable_hash(snapshot)[:10]}"
snapshot["meta"]["snapshot_id"] = snapshot_id
path = self.snapshots_dir / f"{snapshot_id}.json"
path.write_text(
json.dumps(snapshot, indent=2, ensure_ascii=False), encoding="utf-8"
)
meta = SnapshotMeta(
snapshot_id=snapshot_id,
created_at=created_at,
scopes=scopes_norm,
snapshot_path=str(path),
)
self._index[snapshot_id] = meta
return meta, snapshot
def diff(
self,
*,
from_snapshot_id: str,
to_snapshot_id: str,
scopes: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
before = self.load_snapshot(from_snapshot_id)
after = self.load_snapshot(to_snapshot_id)
scopes_before = set(before.get("meta", {}).get("scopes") or [])
scopes_after = set(after.get("meta", {}).get("scopes") or [])
scopes_all = sorted(scopes_before | scopes_after)
scopes_use = sorted(set(scopes or scopes_all))
def index_by_id(
items: Iterable[Mapping[str, Any]],
) -> Dict[str, Dict[str, Any]]:
out: Dict[str, Dict[str, Any]] = {}
for it in items:
_id = it.get("id")
if _id is None:
continue
out[str(_id)] = dict(it)
return out
diff_out: Dict[str, Any] = {
"from": from_snapshot_id,
"to": to_snapshot_id,
"scopes": scopes_use,
"changes": {},
}
for scope in scopes_use:
if scope not in {"tunnels", "access_apps", "zones"}:
continue
b_items = before.get(scope) or []
a_items = after.get(scope) or []
if not isinstance(b_items, list) or not isinstance(a_items, list):
continue
b_map = index_by_id(b_items)
a_map = index_by_id(a_items)
added = [a_map[k] for k in sorted(set(a_map) - set(b_map))]
removed = [b_map[k] for k in sorted(set(b_map) - set(a_map))]
changed: List[Dict[str, Any]] = []
for k in sorted(set(a_map) & set(b_map)):
if stable_hash(a_map[k]) != stable_hash(b_map[k]):
changed.append({"id": k, "before": b_map[k], "after": a_map[k]})
diff_out["changes"][scope] = {
"added": [{"id": x.get("id"), "name": x.get("name")} for x in added],
"removed": [
{"id": x.get("id"), "name": x.get("name")} for x in removed
],
"changed": [
{"id": x.get("id"), "name": x.get("after", {}).get("name")}
for x in changed
],
"counts": {
"added": len(added),
"removed": len(removed),
"changed": len(changed),
},
}
diff_path = self.diffs_dir / f"{from_snapshot_id}_to_{to_snapshot_id}.json"
diff_path.write_text(
json.dumps(diff_out, indent=2, ensure_ascii=False),
encoding="utf-8",
)
diff_out["diff_path"] = str(diff_path)
return diff_out
def parse_cloudflared_config_ingress(config_text: str) -> List[Dict[str, str]]:
"""
Best-effort parser for cloudflared YAML config ingress rules.
We intentionally avoid a YAML dependency; this extracts common patterns:
- hostname: example.com
service: http://127.0.0.1:8080
"""
rules: List[Dict[str, str]] = []
lines = config_text.splitlines()
i = 0
while i < len(lines):
line = lines[i]
stripped = line.lstrip()
if not stripped.startswith("-"):
i += 1
continue
after_dash = stripped[1:].lstrip()
if not after_dash.startswith("hostname:"):
i += 1
continue
hostname = after_dash[len("hostname:") :].strip().strip('"').strip("'")
base_indent = len(line) - len(line.lstrip())
service = ""
j = i + 1
while j < len(lines):
next_line = lines[j]
if next_line.strip() == "":
j += 1
continue
next_indent = len(next_line) - len(next_line.lstrip())
if next_indent <= base_indent:
break
next_stripped = next_line.lstrip()
if next_stripped.startswith("service:"):
service = next_stripped[len("service:") :].strip().strip('"').strip("'")
break
j += 1
rules.append({"hostname": hostname, "service": service})
i = j
return rules
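# Example against a hypothetical cloudflared config:
#   ingress:
#     - hostname: app.example.com
#       service: http://127.0.0.1:8080
#     - service: http_status:404
# yields [{"hostname": "app.example.com", "service": "http://127.0.0.1:8080"}];
# catch-all rules without a hostname are skipped.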
def ingress_summary_from_file(
*,
config_path: str,
max_rules: int = 50,
) -> Dict[str, Any]:
path = Path(config_path)
if not path.exists():
raise CloudflareError(f"cloudflared config not found: {config_path}")
text = path.read_text(encoding="utf-8", errors="replace")
rules = parse_cloudflared_config_ingress(text)
hostnames = sorted({r["hostname"] for r in rules if r.get("hostname")})
return {
"config_path": config_path,
"ingress_rule_count": len(rules),
"hostnames": hostnames[:max_rules],
"rules_sample": rules[:max_rules],
"truncated": len(rules) > max_rules,
}


@@ -0,0 +1,725 @@
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
from .cloudflare_api import (
CloudflareClient,
CloudflareContext,
CloudflareError,
SnapshotStore,
ingress_summary_from_file,
)
MAX_BYTES_DEFAULT = 32_000
def _repo_root() -> Path:
# server.py -> cloudflare_safe -> mcp -> cloudflare -> <repo root>
return Path(__file__).resolve().parents[3]
def _max_bytes() -> int:
raw = (os.getenv("VM_MCP_MAX_BYTES") or "").strip()
if not raw:
return MAX_BYTES_DEFAULT
try:
return max(4_096, int(raw))
except ValueError:
return MAX_BYTES_DEFAULT
def _redact(obj: Any) -> Any:
sensitive_keys = ("token", "secret", "password", "private", "key", "certificate")
if isinstance(obj, dict):
out: Dict[str, Any] = {}
for k, v in obj.items():
if any(s in str(k).lower() for s in sensitive_keys):
out[k] = "<REDACTED>"
else:
out[k] = _redact(v)
return out
if isinstance(obj, list):
return [_redact(v) for v in obj]
if isinstance(obj, str):
if obj.startswith("ghp_") or obj.startswith("github_pat_"):
return "<REDACTED>"
return obj
return obj
def _safe_json(payload: Dict[str, Any]) -> str:
payload = _redact(payload)
raw = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
if len(raw.encode("utf-8")) <= _max_bytes():
return json.dumps(payload, ensure_ascii=False, indent=2)
# Truncate: keep only summary + next_steps.
truncated = {
"ok": payload.get("ok", True),
"truncated": True,
"summary": payload.get("summary", "Response exceeded max size; truncated."),
"next_steps": payload.get(
"next_steps",
[
"request a narrower scope (e.g., scopes=['tunnels'])",
"request an export path instead of inline content",
],
),
}
return json.dumps(truncated, ensure_ascii=False, indent=2)
def _mcp_text_result(
payload: Dict[str, Any], *, is_error: bool = False
) -> Dict[str, Any]:
result: Dict[str, Any] = {
"content": [{"type": "text", "text": _safe_json(payload)}]
}
if is_error:
result["isError"] = True
return result
def _default_state_dir() -> Path:
return _repo_root() / "archive_runtime" / "cloudflare_mcp"
class CloudflareSafeTools:
def __init__(self) -> None:
self.store = SnapshotStore(
Path(os.getenv("VM_CF_MCP_STATE_DIR") or _default_state_dir())
)
def cf_snapshot(
self,
*,
scopes: Optional[Sequence[str]] = None,
zone_id: Optional[str] = None,
zone_name: Optional[str] = None,
dns_max_pages: int = 1,
) -> Dict[str, Any]:
scopes_use = list(scopes or ["tunnels", "access_apps"])
ctx = CloudflareContext.from_env()
client = CloudflareClient(api_token=ctx.api_token)
meta, snapshot = self.store.create_snapshot(
client=client,
ctx=ctx,
scopes=scopes_use,
zone_id=zone_id,
zone_name=zone_name,
dns_max_pages=dns_max_pages,
)
summary = (
f"Snapshot {meta.snapshot_id} captured "
f"(scopes={','.join(meta.scopes)}) and written to {meta.snapshot_path}."
)
return {
"ok": True,
"summary": summary,
"data": {
"snapshot_id": meta.snapshot_id,
"created_at": meta.created_at,
"scopes": meta.scopes,
"snapshot_path": meta.snapshot_path,
"counts": {
"zones": len(snapshot.get("zones") or []),
"tunnels": len(snapshot.get("tunnels") or []),
"access_apps": len(snapshot.get("access_apps") or []),
},
},
"truncated": False,
"next_steps": [
"cf_config_diff(from_snapshot_id=..., to_snapshot_id=...)",
"cf_export_config(full=false, snapshot_id=...)",
],
}
def cf_refresh(
self,
*,
snapshot_id: str,
scopes: Optional[Sequence[str]] = None,
dns_max_pages: int = 1,
) -> Dict[str, Any]:
before_meta = self.store.get(snapshot_id)
before = self.store.load_snapshot(snapshot_id)
scopes_use = list(scopes or (before.get("meta", {}).get("scopes") or []))
ctx = CloudflareContext.from_env()
client = CloudflareClient(api_token=ctx.api_token)
meta, _snapshot = self.store.create_snapshot(
client=client,
ctx=ctx,
scopes=scopes_use,
zone_id=(before.get("dns") or {}).get("zone_id"),
zone_name=(before.get("dns") or {}).get("zone_name"),
dns_max_pages=dns_max_pages,
)
return {
"ok": True,
"summary": f"Refreshed {before_meta.snapshot_id} -> {meta.snapshot_id} (scopes={','.join(meta.scopes)}).",
"data": {
"from_snapshot_id": before_meta.snapshot_id,
"to_snapshot_id": meta.snapshot_id,
"snapshot_path": meta.snapshot_path,
},
"truncated": False,
"next_steps": [
"cf_config_diff(from_snapshot_id=..., to_snapshot_id=...)",
],
}
def cf_config_diff(
self,
*,
from_snapshot_id: str,
to_snapshot_id: str,
scopes: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
diff = self.store.diff(
from_snapshot_id=from_snapshot_id,
to_snapshot_id=to_snapshot_id,
scopes=scopes,
)
# Keep the response small; point to diff_path for full detail.
changes = diff.get("changes") or {}
counts = {
scope: (changes.get(scope) or {}).get("counts")
for scope in sorted(changes.keys())
}
return {
"ok": True,
"summary": f"Diff computed and written to {diff.get('diff_path')}.",
"data": {
"from_snapshot_id": from_snapshot_id,
"to_snapshot_id": to_snapshot_id,
"scopes": diff.get("scopes"),
"counts": counts,
"diff_path": diff.get("diff_path"),
},
"truncated": False,
"next_steps": [
"Use filesystem MCP to open diff_path for full details",
"Run cf_export_config(full=false, snapshot_id=...) for a safe export path",
],
}
def cf_export_config(
self,
*,
snapshot_id: Optional[str] = None,
full: bool = False,
scopes: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
if snapshot_id is None:
snap = self.cf_snapshot(scopes=scopes)
snapshot_id = str((snap.get("data") or {}).get("snapshot_id"))
meta = self.store.get(snapshot_id)
if not full:
return {
"ok": True,
"summary": "Export is summary-first; full config requires full=true.",
"data": {
"snapshot_id": meta.snapshot_id,
"snapshot_path": meta.snapshot_path,
},
"truncated": False,
"next_steps": [
"Use filesystem MCP to open snapshot_path",
"If you truly need inline data, call cf_export_config(full=true, snapshot_id=...)",
],
}
snapshot = self.store.load_snapshot(snapshot_id)
return {
"ok": True,
"summary": "Full snapshot export (redacted + size-capped). Prefer snapshot_path for large data.",
"data": snapshot,
"truncated": False,
"next_steps": [
f"Snapshot file: {meta.snapshot_path}",
],
}
def cf_tunnel_status(
self,
*,
snapshot_id: Optional[str] = None,
tunnel_name: Optional[str] = None,
tunnel_id: Optional[str] = None,
) -> Dict[str, Any]:
if snapshot_id:
snap = self.store.load_snapshot(snapshot_id)
tunnels = snap.get("tunnels") or []
else:
snap = self.cf_snapshot(scopes=["tunnels"])
sid = str((snap.get("data") or {}).get("snapshot_id"))
tunnels = self.store.load_snapshot(sid).get("tunnels") or []
def matches(t: Dict[str, Any]) -> bool:
if tunnel_id and str(t.get("id")) != str(tunnel_id):
return False
if tunnel_name and str(t.get("name")) != str(tunnel_name):
return False
return True
filtered = [t for t in tunnels if isinstance(t, dict) and matches(t)]
if not filtered and (tunnel_id or tunnel_name):
return {
"ok": False,
"summary": "Tunnel not found in snapshot.",
"data": {"tunnel_id": tunnel_id, "tunnel_name": tunnel_name},
"truncated": False,
"next_steps": ["Call cf_snapshot(scopes=['tunnels']) and retry."],
}
connectors = [t.get("connector_count") for t in filtered if isinstance(t, dict)]
connectors = [c for c in connectors if isinstance(c, int)]
return {
"ok": True,
"summary": f"Returned {len(filtered)} tunnel(s).",
"data": {
"tunnels": [
{
"id": t.get("id"),
"name": t.get("name"),
"status": t.get("status"),
"connector_count": t.get("connector_count"),
"last_seen": t.get("last_seen"),
}
for t in filtered
],
"connectors_total": sum(connectors) if connectors else 0,
},
"truncated": False,
"next_steps": [
"For local ingress hostnames, use cf_tunnel_ingress_summary(config_path='/etc/cloudflared/config.yml')",
],
}
def cf_tunnel_ingress_summary(
self,
*,
config_path: str = "/etc/cloudflared/config.yml",
full: bool = False,
max_rules: int = 50,
) -> Dict[str, Any]:
summary = ingress_summary_from_file(
config_path=config_path, max_rules=max_rules
)
if not full:
return {
"ok": True,
"summary": f"Parsed ingress hostnames from {config_path}.",
"data": {
"config_path": summary["config_path"],
"ingress_rule_count": summary["ingress_rule_count"],
"hostnames": summary["hostnames"],
"truncated": summary["truncated"],
},
"truncated": False,
"next_steps": [
"Call cf_tunnel_ingress_summary(full=true, ...) to include service mappings (still capped).",
],
}
return {
"ok": True,
"summary": f"Ingress summary (full=true) for {config_path}.",
"data": summary,
"truncated": False,
"next_steps": [],
}
def cf_access_policy_list(
self,
*,
app_id: Optional[str] = None,
) -> Dict[str, Any]:
ctx = CloudflareContext.from_env()
client = CloudflareClient(api_token=ctx.api_token)
if not app_id:
apps = client.list_access_apps(ctx.account_id)
apps_min = [
{
"id": a.get("id"),
"name": a.get("name"),
"domain": a.get("domain"),
"type": a.get("type"),
}
for a in apps
]
return {
"ok": True,
"summary": f"Returned {len(apps_min)} Access app(s). Provide app_id to list policies.",
"data": {"apps": apps_min},
"truncated": False,
"next_steps": [
"Call cf_access_policy_list(app_id=...)",
],
}
policies = client.list_access_policies(ctx.account_id, app_id)
policies_min = [
{
"id": p.get("id"),
"name": p.get("name"),
"decision": p.get("decision"),
"precedence": p.get("precedence"),
}
for p in policies
]
return {
"ok": True,
"summary": f"Returned {len(policies_min)} policy/policies for app_id={app_id}.",
"data": {"app_id": app_id, "policies": policies_min},
"truncated": False,
"next_steps": [],
}
TOOLS: List[Dict[str, Any]] = [
{
"name": "cf_snapshot",
"description": "Create a summary-first Cloudflare state snapshot (writes JSON to disk; returns snapshot_id + paths).",
"inputSchema": {
"type": "object",
"properties": {
"scopes": {
"type": "array",
"items": {"type": "string"},
"description": "Scopes to fetch (default: ['tunnels','access_apps']). Supported: zones,tunnels,access_apps,dns",
},
"zone_id": {"type": "string"},
"zone_name": {"type": "string"},
"dns_max_pages": {"type": "integer", "default": 1},
},
},
},
{
"name": "cf_refresh",
"description": "Refresh a prior snapshot (creates a new snapshot_id).",
"inputSchema": {
"type": "object",
"properties": {
"snapshot_id": {"type": "string"},
"scopes": {"type": "array", "items": {"type": "string"}},
"dns_max_pages": {"type": "integer", "default": 1},
},
"required": ["snapshot_id"],
},
},
{
"name": "cf_config_diff",
"description": "Diff two snapshots (summary counts inline; full diff written to disk).",
"inputSchema": {
"type": "object",
"properties": {
"from_snapshot_id": {"type": "string"},
"to_snapshot_id": {"type": "string"},
"scopes": {"type": "array", "items": {"type": "string"}},
},
"required": ["from_snapshot_id", "to_snapshot_id"],
},
},
{
"name": "cf_export_config",
"description": "Export snapshot config. Defaults to summary-only; full=true returns redacted + size-capped data.",
"inputSchema": {
"type": "object",
"properties": {
"snapshot_id": {"type": "string"},
"full": {"type": "boolean", "default": False},
"scopes": {"type": "array", "items": {"type": "string"}},
},
},
},
{
"name": "cf_tunnel_status",
"description": "Return tunnel status summary (connector count, last seen).",
"inputSchema": {
"type": "object",
"properties": {
"snapshot_id": {"type": "string"},
"tunnel_name": {"type": "string"},
"tunnel_id": {"type": "string"},
},
},
},
{
"name": "cf_tunnel_ingress_summary",
"description": "Parse cloudflared ingress hostnames from a local config file (never dumps full YAML unless full=true, still capped).",
"inputSchema": {
"type": "object",
"properties": {
"config_path": {
"type": "string",
"default": "/etc/cloudflared/config.yml",
},
"full": {"type": "boolean", "default": False},
"max_rules": {"type": "integer", "default": 50},
},
},
},
{
"name": "cf_access_policy_list",
"description": "List Access apps, or policies for a specific app_id (summary-only).",
"inputSchema": {
"type": "object",
"properties": {
"app_id": {"type": "string"},
},
},
},
]
class StdioJsonRpc:
def __init__(self) -> None:
self._in = sys.stdin.buffer
self._out = sys.stdout.buffer
self._mode: str | None = None # "headers" | "line"
def read_message(self) -> Optional[Dict[str, Any]]:
while True:
if self._mode == "line":
line = self._in.readline()
if not line:
return None
raw = line.decode("utf-8", "replace").strip()
if not raw:
continue
try:
msg = json.loads(raw)
except Exception:
continue
if isinstance(msg, dict):
return msg
continue
first = self._in.readline()
if not first:
return None
if first in (b"\r\n", b"\n"):
continue
# Auto-detect newline-delimited JSON framing.
if self._mode is None and first.lstrip().startswith(b"{"):
try:
msg = json.loads(first.decode("utf-8", "replace"))
except Exception:
msg = None
if isinstance(msg, dict):
self._mode = "line"
return msg
headers: Dict[str, str] = {}
try:
text = first.decode("utf-8", "replace").strip()
except Exception:
continue
if ":" not in text:
continue
k, v = text.split(":", 1)
headers[k.lower().strip()] = v.strip()
while True:
line = self._in.readline()
if not line:
return None
if line in (b"\r\n", b"\n"):
break
try:
text = line.decode("utf-8", "replace").strip()
except Exception:
continue
if ":" not in text:
continue
k, v = text.split(":", 1)
headers[k.lower().strip()] = v.strip()
if "content-length" not in headers:
return None
try:
length = int(headers["content-length"])
except ValueError:
return None
body = self._in.read(length)
if not body:
return None
self._mode = "headers"
msg = json.loads(body.decode("utf-8", "replace"))
if isinstance(msg, dict):
return msg
return None
def write_message(self, message: Dict[str, Any]) -> None:
if self._mode == "line":
payload = json.dumps(
message, ensure_ascii=False, separators=(",", ":"), default=str
).encode("utf-8")
self._out.write(payload + b"\n")
self._out.flush()
return
body = json.dumps(message, ensure_ascii=False, separators=(",", ":")).encode(
"utf-8"
)
header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
self._out.write(header)
self._out.write(body)
self._out.flush()
def main() -> None:
tools = CloudflareSafeTools()
rpc = StdioJsonRpc()
handlers: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] = {
"cf_snapshot": lambda a: tools.cf_snapshot(**a),
"cf_refresh": lambda a: tools.cf_refresh(**a),
"cf_config_diff": lambda a: tools.cf_config_diff(**a),
"cf_export_config": lambda a: tools.cf_export_config(**a),
"cf_tunnel_status": lambda a: tools.cf_tunnel_status(**a),
"cf_tunnel_ingress_summary": lambda a: tools.cf_tunnel_ingress_summary(**a),
"cf_access_policy_list": lambda a: tools.cf_access_policy_list(**a),
}
while True:
msg = rpc.read_message()
if msg is None:
return
method = msg.get("method")
msg_id = msg.get("id")
params = msg.get("params") or {}
try:
if method == "initialize":
result = {
"protocolVersion": "2024-11-05",
"serverInfo": {"name": "cloudflare_safe", "version": "0.1.0"},
"capabilities": {"tools": {}},
}
rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result})
continue
if method == "tools/list":
rpc.write_message(
{"jsonrpc": "2.0", "id": msg_id, "result": {"tools": TOOLS}}
)
continue
if method == "tools/call":
tool_name = str(params.get("name") or "")
args = params.get("arguments") or {}
if tool_name not in handlers:
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
{
"ok": False,
"summary": f"Unknown tool: {tool_name}",
"data": {"known_tools": sorted(handlers.keys())},
"truncated": False,
"next_steps": ["Call tools/list"],
},
is_error=True,
),
}
)
continue
try:
payload = handlers[tool_name](args)
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(payload),
}
)
except CloudflareError as e:
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
{
"ok": False,
"summary": str(e),
"truncated": False,
"next_steps": [
"Verify CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID are set",
"Retry with a narrower scope",
],
},
is_error=True,
),
}
)
except Exception as e: # noqa: BLE001
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
{
"ok": False,
"summary": f"Unhandled error: {e}",
"truncated": False,
"next_steps": ["Retry with a narrower scope"],
},
is_error=True,
),
}
)
continue
# Ignore notifications.
if msg_id is None:
continue
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
{
"ok": False,
"summary": f"Unsupported method: {method}",
"truncated": False,
},
is_error=True,
),
}
)
except Exception as e: # noqa: BLE001
# Last-resort: avoid crashing the server.
if msg_id is not None:
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
{
"ok": False,
"summary": f"fatal error: {e}",
"truncated": False,
},
),
}
)


@@ -0,0 +1,6 @@
from __future__ import annotations
from .server import main
if __name__ == "__main__":
main()

mcp/oracle_answer/server.py

@@ -0,0 +1,386 @@
from __future__ import annotations
import asyncio
import json
import os
import sys
from typing import Any, Callable, Dict, List, Optional
from layer0 import layer0_entry
from layer0.shadow_classifier import ShadowEvalResult
from .tool import OracleAnswerTool
MAX_BYTES_DEFAULT = 32_000
def _max_bytes() -> int:
raw = (os.getenv("VM_MCP_MAX_BYTES") or "").strip()
if not raw:
return MAX_BYTES_DEFAULT
try:
return max(4_096, int(raw))
except ValueError:
return MAX_BYTES_DEFAULT
def _redact(obj: Any) -> Any:
sensitive_keys = ("token", "secret", "password", "private", "key", "certificate")
if isinstance(obj, dict):
out: Dict[str, Any] = {}
for k, v in obj.items():
if any(s in str(k).lower() for s in sensitive_keys):
out[k] = "<REDACTED>"
else:
out[k] = _redact(v)
return out
if isinstance(obj, list):
return [_redact(v) for v in obj]
if isinstance(obj, str):
if obj.startswith("ghp_") or obj.startswith("github_pat_"):
return "<REDACTED>"
return obj
return obj
def _safe_json(payload: Dict[str, Any]) -> str:
payload = _redact(payload)
raw = json.dumps(payload, ensure_ascii=False, separators=(",", ":"), default=str)
if len(raw.encode("utf-8")) <= _max_bytes():
return json.dumps(payload, ensure_ascii=False, indent=2, default=str)
truncated = {
"ok": payload.get("ok", True),
"truncated": True,
"summary": payload.get("summary", "Response exceeded max size; truncated."),
"next_steps": payload.get(
"next_steps",
["request narrower outputs (e.g., fewer frameworks or shorter question)"],
),
}
return json.dumps(truncated, ensure_ascii=False, indent=2, default=str)
def _mcp_text_result(
payload: Dict[str, Any], *, is_error: bool = False
) -> Dict[str, Any]:
result: Dict[str, Any] = {
"content": [{"type": "text", "text": _safe_json(payload)}]
}
if is_error:
result["isError"] = True
return result
TOOLS: List[Dict[str, Any]] = [
{
"name": "oracle_answer",
"description": "Answer a compliance/security question (optionally via NVIDIA LLM) and map to frameworks.",
"inputSchema": {
"type": "object",
"properties": {
"question": {
"type": "string",
"description": "The question to answer.",
},
"frameworks": {
"type": "array",
"items": {"type": "string"},
"description": "Frameworks to reference (e.g., ['NIST-CSF','ISO-27001','GDPR']).",
},
"mode": {
"type": "string",
"enum": ["strict", "advisory"],
"default": "strict",
"description": "strict=conservative, advisory=exploratory.",
},
"local_only": {
"type": "boolean",
"description": "If true, skip NVIDIA API calls (uses local-only mode). Defaults to true when NVIDIA_API_KEY is missing.",
},
},
"required": ["question"],
},
}
]
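# Example tools/call arguments (question and framework list are illustrative):
#   {"name": "oracle_answer",
#    "arguments": {"question": "Which controls apply to storing EU customer PII?",
#                  "frameworks": ["GDPR", "ISO-27001"],
#                  "mode": "strict",
#                  "local_only": True}}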
class OracleAnswerTools:
async def oracle_answer(
self,
*,
question: str,
frameworks: Optional[List[str]] = None,
mode: str = "strict",
local_only: Optional[bool] = None,
) -> Dict[str, Any]:
routing_action, shadow = layer0_entry(question)
if routing_action != "HANDOFF_TO_LAYER1":
return _layer0_payload(routing_action, shadow)
local_only_use = (
bool(local_only)
if local_only is not None
else not bool((os.getenv("NVIDIA_API_KEY") or "").strip())
)
try:
tool = OracleAnswerTool(
default_frameworks=frameworks,
use_local_only=local_only_use,
)
except Exception as e: # noqa: BLE001
return {
"ok": False,
"summary": str(e),
"data": {
"local_only": local_only_use,
"has_nvidia_api_key": bool(
(os.getenv("NVIDIA_API_KEY") or "").strip()
),
},
"truncated": False,
"next_steps": [
"Set NVIDIA_API_KEY to enable live answers",
"Or call oracle_answer(local_only=true, ...)",
],
}
resp = await tool.answer(question=question, frameworks=frameworks, mode=mode)
return {
"ok": True,
"summary": "Oracle answer generated.",
"data": {
"question": question,
"mode": mode,
"frameworks": frameworks or tool.default_frameworks,
"local_only": local_only_use,
"model": resp.model,
"answer": resp.answer,
"framework_hits": resp.framework_hits,
"reasoning": resp.reasoning,
},
"truncated": False,
"next_steps": [
"If the answer is incomplete, add more specifics to the question or include more frameworks.",
],
}
class StdioJsonRpc:
def __init__(self) -> None:
self._in = sys.stdin.buffer
self._out = sys.stdout.buffer
self._mode: str | None = None # "headers" | "line"
def read_message(self) -> Optional[Dict[str, Any]]:
while True:
if self._mode == "line":
line = self._in.readline()
if not line:
return None
raw = line.decode("utf-8", "replace").strip()
if not raw:
continue
try:
msg = json.loads(raw)
except Exception:
continue
if isinstance(msg, dict):
return msg
continue
first = self._in.readline()
if not first:
return None
if first in (b"\r\n", b"\n"):
continue
# Auto-detect newline-delimited JSON framing.
if self._mode is None and first.lstrip().startswith(b"{"):
try:
msg = json.loads(first.decode("utf-8", "replace"))
except Exception:
msg = None
if isinstance(msg, dict):
self._mode = "line"
return msg
headers: Dict[str, str] = {}
try:
text = first.decode("utf-8", "replace").strip()
except Exception:
continue
if ":" not in text:
continue
k, v = text.split(":", 1)
headers[k.lower().strip()] = v.strip()
while True:
line = self._in.readline()
if not line:
return None
if line in (b"\r\n", b"\n"):
break
try:
text = line.decode("utf-8", "replace").strip()
except Exception:
continue
if ":" not in text:
continue
k, v = text.split(":", 1)
headers[k.lower().strip()] = v.strip()
if "content-length" not in headers:
return None
try:
length = int(headers["content-length"])
except ValueError:
return None
body = self._in.read(length)
if not body:
return None
self._mode = "headers"
msg = json.loads(body.decode("utf-8", "replace"))
if isinstance(msg, dict):
return msg
return None
def write_message(self, message: Dict[str, Any]) -> None:
if self._mode == "line":
payload = json.dumps(
message, ensure_ascii=False, separators=(",", ":"), default=str
).encode("utf-8")
self._out.write(payload + b"\n")
self._out.flush()
return
body = json.dumps(
message, ensure_ascii=False, separators=(",", ":"), default=str
).encode("utf-8")
header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
self._out.write(header)
self._out.write(body)
self._out.flush()
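# Framing note (descriptive only): read_message() auto-detects the wire framing from the
# first bytes it sees. LSP-style header framing looks like
#     Content-Length: 74\r\n\r\n{"jsonrpc":"2.0","id":1,"method":"tools/list"}
# while newline-delimited framing is simply
#     {"jsonrpc":"2.0","id":1,"method":"tools/list"}\n
# write_message() then replies using whichever framing was detected.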
def main() -> None:
tools = OracleAnswerTools()
rpc = StdioJsonRpc()
handlers: Dict[str, Callable[[Dict[str, Any]], Any]] = {
"oracle_answer": lambda a: tools.oracle_answer(**a),
}
while True:
msg = rpc.read_message()
if msg is None:
return
method = msg.get("method")
msg_id = msg.get("id")
params = msg.get("params") or {}
try:
if method == "initialize":
result = {
"protocolVersion": "2024-11-05",
"serverInfo": {"name": "oracle_answer", "version": "0.1.0"},
"capabilities": {"tools": {}},
}
rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result})
continue
if method == "tools/list":
rpc.write_message(
{"jsonrpc": "2.0", "id": msg_id, "result": {"tools": TOOLS}}
)
continue
if method == "tools/call":
tool_name = str(params.get("name") or "")
args = params.get("arguments") or {}
handler = handlers.get(tool_name)
if not handler:
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
{
"ok": False,
"summary": f"Unknown tool: {tool_name}",
"data": {"known_tools": sorted(handlers.keys())},
"truncated": False,
"next_steps": ["Call tools/list"],
},
is_error=True,
),
}
)
continue
payload = asyncio.run(handler(args)) # type: ignore[arg-type]
is_error = (
not bool(payload.get("ok", True))
if isinstance(payload, dict)
else False
)
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(payload, is_error=is_error),
}
)
continue
# Ignore notifications.
if msg_id is None:
continue
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
{"ok": False, "summary": f"Unsupported method: {method}"},
is_error=True,
),
}
)
except Exception as e: # noqa: BLE001
if msg_id is not None:
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
{"ok": False, "summary": f"fatal error: {e}"},
is_error=True,
),
}
)
def _layer0_payload(routing_action: str, shadow: ShadowEvalResult) -> Dict[str, Any]:
if routing_action == "FAIL_CLOSED":
return {"ok": False, "summary": "Layer 0: cannot comply with this request."}
if routing_action == "HANDOFF_TO_GUARDRAILS":
reason = shadow.reason or "governance_violation"
return {
"ok": False,
"summary": f"Layer 0: governance violation detected ({reason}).",
}
if routing_action == "PROMPT_FOR_CLARIFICATION":
return {
"ok": False,
"summary": "Layer 0: request is ambiguous. Please clarify and retry.",
}
return {"ok": False, "summary": "Layer 0: unrecognized routing action; refusing."}
if __name__ == "__main__":
main()

View File

@@ -9,7 +9,11 @@ Separate from CLI/API wrapper for clean testability.
from __future__ import annotations
import asyncio
import json
import os
import urllib.error
import urllib.request
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
@@ -92,12 +96,10 @@ class OracleAnswerTool:
if self.use_local_only:
return "Local-only mode: skipping NVIDIA API call"
if not httpx:
raise ImportError("httpx not installed. Install with: pip install httpx")
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = {
@@ -108,18 +110,45 @@ class OracleAnswerTool:
"max_tokens": 1024,
}
try:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.NVIDIA_API_BASE}/chat/completions",
json=payload,
headers=headers,
timeout=30.0,
)
response.raise_for_status()
data = response.json()
# Prefer httpx when available; otherwise fall back to stdlib urllib to avoid extra deps.
if httpx:
try:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.NVIDIA_API_BASE}/chat/completions",
json=payload,
headers=headers,
timeout=30.0,
)
response.raise_for_status()
data = response.json()
return data["choices"][0]["message"]["content"]
except Exception as e: # noqa: BLE001
return f"(API Error: {str(e)}) Falling back to local analysis..."
def _urllib_post() -> str:
req = urllib.request.Request(
url=f"{self.NVIDIA_API_BASE}/chat/completions",
method="POST",
headers=headers,
data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
)
try:
with urllib.request.urlopen(req, timeout=30) as resp:
raw = resp.read().decode("utf-8", "replace")
data = json.loads(raw)
return data["choices"][0]["message"]["content"]
except Exception as e:
except urllib.error.HTTPError as e:
detail = ""
try:
detail = e.read().decode("utf-8", "replace")
except Exception:
detail = str(e)
raise RuntimeError(f"HTTP {e.code}: {detail}") from e
try:
return await asyncio.to_thread(_urllib_post)
except Exception as e: # noqa: BLE001
return f"(API Error: {str(e)}) Falling back to local analysis..."
async def answer(

View File

@@ -10,22 +10,24 @@ This module provides tools to:
Export primary classes and functions:
"""
from mcp.waf_intelligence.analyzer import (
WAFRuleAnalyzer,
RuleViolation,
__version__ = "0.3.0"
from .analyzer import (
AnalysisResult,
RuleViolation,
WAFRuleAnalyzer,
)
from mcp.waf_intelligence.generator import (
WAFRuleGenerator,
GeneratedRule,
)
from mcp.waf_intelligence.compliance import (
from .compliance import (
ComplianceMapper,
FrameworkMapping,
)
from mcp.waf_intelligence.orchestrator import (
WAFIntelligence,
from .generator import (
GeneratedRule,
WAFRuleGenerator,
)
from .orchestrator import (
WAFInsight,
WAFIntelligence,
)
__all__ = [

View File

@@ -10,6 +10,7 @@ from typing import Any, Dict, List
from layer0 import layer0_entry
from layer0.shadow_classifier import ShadowEvalResult
from . import __version__ as WAF_INTEL_VERSION
from .orchestrator import WAFInsight, WAFIntelligence
@@ -56,11 +57,18 @@ def run_cli(argv: List[str] | None = None) -> int:
action="store_true",
help="Exit with non-zero code if any error-severity violations are found.",
)
parser.add_argument(
"--version",
action="version",
version=f"%(prog)s {WAF_INTEL_VERSION}",
)
args = parser.parse_args(argv)
# Layer 0: pre-boot Shadow Eval gate.
routing_action, shadow = layer0_entry(f"waf_intel_cli file={args.file} limit={args.limit}")
routing_action, shadow = layer0_entry(
f"waf_intel_cli file={args.file} limit={args.limit}"
)
if routing_action != "HANDOFF_TO_LAYER1":
_render_layer0_block(routing_action, shadow)
return 1
@@ -90,7 +98,9 @@ def run_cli(argv: List[str] | None = None) -> int:
print(f"\nWAF Intelligence Report for: {path}\n{'-' * 72}")
if not insights:
print("No high-severity, high-confidence issues detected based on current heuristics.")
print(
"No high-severity, high-confidence issues detected based on current heuristics."
)
return 0
for idx, insight in enumerate(insights, start=1):
@@ -119,7 +129,9 @@ def run_cli(argv: List[str] | None = None) -> int:
if insight.mappings:
print("\nCompliance Mapping:")
for mapping in insight.mappings:
print(f" - {mapping.framework} {mapping.control_id}: {mapping.description}")
print(
f" - {mapping.framework} {mapping.control_id}: {mapping.description}"
)
print()

View File

@@ -1,9 +1,16 @@
from __future__ import annotations
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
MANAGED_WAF_RULESET_IDS = (
# Cloudflare managed WAF ruleset IDs (last updated 2025-12-18).
"efb7b8c949ac4650a09736fc376e9aee", # Cloudflare Managed Ruleset
"4814384a9e5d4991b9815dcfc25d2f1f", # OWASP Core Ruleset
)
@dataclass
class RuleViolation:
@@ -57,6 +64,20 @@ class WAFRuleAnalyzer:
Analyze Cloudflare WAF rules from Terraform with a quality-first posture.
"""
def _has_managed_waf_rules(self, text: str) -> bool:
text_lower = text.lower()
if "managed_rules" in text_lower:
return True
if re.search(r'phase\s*=\s*"http_request_firewall_managed"', text_lower):
return True
if "cf.waf" in text_lower:
return True
return any(ruleset_id in text_lower for ruleset_id in MANAGED_WAF_RULESET_IDS)
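# Illustrative Terraform fragment (not from this repo) that this heuristic would treat
# as "managed WAF present": it matches both the phase string and a managed ruleset ID.
#   resource "cloudflare_ruleset" "managed_waf" {
#     phase = "http_request_firewall_managed"
#     rules {
#       action            = "execute"
#       action_parameters { id = "efb7b8c949ac4650a09736fc376e9aee" }
#     }
#   }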
def analyze_file(
self,
path: str | Path,
@@ -70,7 +91,7 @@ class WAFRuleAnalyzer:
violations: List[RuleViolation] = []
# Example heuristic: no managed rules present
if "managed_rules" not in text:
if not self._has_managed_waf_rules(text):
violations.append(
RuleViolation(
rule_id=None,
@@ -102,7 +123,7 @@ class WAFRuleAnalyzer:
violations=violations,
metadata={
"file_size": path.stat().st_size,
"heuristics_version": "0.2.0",
"heuristics_version": "0.3.0",
},
)
@@ -125,7 +146,7 @@ class WAFRuleAnalyzer:
tmp_path = Path(source_name)
violations: List[RuleViolation] = []
if "managed_rules" not in text:
if not self._has_managed_waf_rules(text):
violations.append(
RuleViolation(
rule_id=None,
@@ -141,7 +162,7 @@ class WAFRuleAnalyzer:
result = AnalysisResult(
source=str(tmp_path),
violations=violations,
metadata={"heuristics_version": "0.2.0"},
metadata={"heuristics_version": "0.3.0"},
)
result.violations = result.top_violations(
@@ -161,27 +182,37 @@ class WAFRuleAnalyzer:
) -> AnalysisResult:
"""
Enhanced analysis using threat intelligence data.
Args:
path: WAF config file path
threat_indicators: List of ThreatIndicator objects from threat_intel module
min_severity: Minimum severity to include
min_confidence: Minimum confidence threshold
Returns:
AnalysisResult with violations informed by threat intel
"""
# Start with base analysis
base_result = self.analyze_file(path, min_severity=min_severity, min_confidence=min_confidence)
base_result = self.analyze_file(
path, min_severity=min_severity, min_confidence=min_confidence
)
path = Path(path)
text = path.read_text(encoding="utf-8")
text_lower = text.lower()
# Check if threat indicators are addressed by existing rules
critical_ips = [i for i in threat_indicators if i.indicator_type == "ip" and i.severity in ("critical", "high")]
critical_patterns = [i for i in threat_indicators if i.indicator_type == "pattern" and i.severity in ("critical", "high")]
critical_ips = [
i
for i in threat_indicators
if i.indicator_type == "ip" and i.severity in ("critical", "high")
]
critical_patterns = [
i
for i in threat_indicators
if i.indicator_type == "pattern" and i.severity in ("critical", "high")
]
# Check for IP blocking coverage
if critical_ips:
ip_block_present = "ip.src" in text_lower or "cf.client.ip" in text_lower
@@ -197,14 +228,14 @@ class WAFRuleAnalyzer:
hint=f"Add IP blocking rules for identified threat actors. Sample IPs: {', '.join(i.value for i in critical_ips[:3])}",
)
)
# Check for pattern-based attack coverage
attack_types_seen = set()
for ind in critical_patterns:
for tag in ind.tags:
if tag in ("sqli", "xss", "rce", "path_traversal"):
attack_types_seen.add(tag)
# Check managed ruleset coverage
for attack_type in attack_types_seen:
if attack_type not in text_lower and f'"{attack_type}"' not in text_lower:
@@ -219,13 +250,12 @@ class WAFRuleAnalyzer:
hint=f"Enable Cloudflare managed rules for {attack_type.upper()} protection.",
)
)
# Update metadata with threat intel stats
base_result.metadata["threat_intel"] = {
"critical_ips": len(critical_ips),
"critical_patterns": len(critical_patterns),
"attack_types_seen": list(attack_types_seen),
}
return base_result
return base_result

View File

@@ -0,0 +1,632 @@
from __future__ import annotations
import glob
import json
import os
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from cloudflare.layer0 import layer0_entry
from cloudflare.layer0.shadow_classifier import ShadowEvalResult
from .orchestrator import ThreatAssessment, WAFInsight, WAFIntelligence
MAX_BYTES_DEFAULT = 32_000
def _cloudflare_root() -> Path:
# mcp_server.py -> waf_intelligence -> mcp -> cloudflare
return Path(__file__).resolve().parents[2]
def _max_bytes() -> int:
raw = (os.getenv("VM_MCP_MAX_BYTES") or "").strip()
if not raw:
return MAX_BYTES_DEFAULT
try:
return max(4_096, int(raw))
except ValueError:
return MAX_BYTES_DEFAULT
def _redact(obj: Any) -> Any:
sensitive_keys = ("token", "secret", "password", "private", "key", "certificate")
if isinstance(obj, dict):
out: Dict[str, Any] = {}
for k, v in obj.items():
if any(s in str(k).lower() for s in sensitive_keys):
out[k] = "<REDACTED>"
else:
out[k] = _redact(v)
return out
if isinstance(obj, list):
return [_redact(v) for v in obj]
if isinstance(obj, str):
if obj.startswith("ghp_") or obj.startswith("github_pat_"):
return "<REDACTED>"
return obj
return obj
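# For example (descriptive only):
#   _redact({"api_token": "abc123", "zone": "example.com"})
#   -> {"api_token": "<REDACTED>", "zone": "example.com"}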
def _safe_json(payload: Dict[str, Any]) -> str:
payload = _redact(payload)
raw = json.dumps(payload, ensure_ascii=False, separators=(",", ":"), default=str)
if len(raw.encode("utf-8")) <= _max_bytes():
return json.dumps(payload, ensure_ascii=False, indent=2, default=str)
truncated = {
"ok": payload.get("ok", True),
"truncated": True,
"summary": payload.get("summary", "Response exceeded max size; truncated."),
"next_steps": payload.get(
"next_steps",
[
"request fewer files/insights (limit=...)",
"use higher min_severity to reduce output",
],
),
}
return json.dumps(truncated, ensure_ascii=False, indent=2, default=str)
def _mcp_text_result(
payload: Dict[str, Any], *, is_error: bool = False
) -> Dict[str, Any]:
result: Dict[str, Any] = {
"content": [{"type": "text", "text": _safe_json(payload)}]
}
if is_error:
result["isError"] = True
return result
def _insight_to_dict(insight: WAFInsight) -> Dict[str, Any]:
return asdict(insight)
def _assessment_to_dict(assessment: ThreatAssessment) -> Dict[str, Any]:
violations = []
if assessment.analysis_result and getattr(
assessment.analysis_result, "violations", None
):
violations = list(assessment.analysis_result.violations)
severity_counts = {"error": 0, "warning": 0, "info": 0}
for v in violations:
sev = getattr(v, "severity", "info")
if sev in severity_counts:
severity_counts[sev] += 1
return {
"risk_score": assessment.risk_score,
"risk_level": assessment.risk_level,
"classification_summary": assessment.classification_summary,
"recommended_actions": assessment.recommended_actions,
"analysis": {
"has_config_analysis": assessment.analysis_result is not None,
"violations_total": len(violations),
"violations_by_severity": severity_counts,
},
"has_threat_intel": assessment.threat_report is not None,
"generated_at": str(assessment.generated_at),
}
TOOLS: List[Dict[str, Any]] = [
{
"name": "waf_capabilities",
"description": "List available WAF Intelligence capabilities.",
"inputSchema": {"type": "object", "properties": {}},
},
{
"name": "analyze_waf",
"description": "Analyze Terraform WAF file(s) and return curated insights (legacy alias for waf_analyze).",
"inputSchema": {
"type": "object",
"properties": {
"file": {
"type": "string",
"description": "Single file path to analyze.",
},
"files": {
"type": "array",
"items": {"type": "string"},
"description": "List of file paths or glob patterns to analyze.",
},
"limit": {
"type": "integer",
"default": 3,
"description": "Max insights per file.",
},
"severity_threshold": {
"type": "string",
"enum": ["info", "warning", "error"],
"default": "warning",
"description": "Minimum severity to include (alias for min_severity).",
},
},
},
},
{
"name": "waf_analyze",
"description": "Analyze Terraform WAF file(s) and return curated insights (requires file or files).",
"inputSchema": {
"type": "object",
"properties": {
"file": {
"type": "string",
"description": "Single file path to analyze.",
},
"files": {
"type": "array",
"items": {"type": "string"},
"description": "List of file paths or glob patterns to analyze.",
},
"limit": {
"type": "integer",
"default": 3,
"description": "Max insights per file.",
},
"min_severity": {
"type": "string",
"enum": ["info", "warning", "error"],
"default": "warning",
"description": "Minimum severity to include.",
},
},
},
},
{
"name": "waf_assess",
"description": "Run a broader assessment (optionally includes threat intel collection).",
"inputSchema": {
"type": "object",
"properties": {
"waf_config_path": {
"type": "string",
"description": "Path to Terraform WAF config (default: terraform/waf.tf).",
},
"include_threat_intel": {
"type": "boolean",
"default": False,
"description": "If true, attempt to collect threat intel (may require network and credentials).",
},
},
},
},
{
"name": "waf_generate_gitops_proposals",
"description": "Generate GitOps-ready rule proposals (best-effort; requires threat intel to produce output).",
"inputSchema": {
"type": "object",
"properties": {
"waf_config_path": {
"type": "string",
"description": "Path to Terraform WAF config (default: terraform/waf.tf).",
},
"include_threat_intel": {
"type": "boolean",
"default": True,
"description": "Attempt to collect threat intel before proposing rules.",
},
"max_proposals": {
"type": "integer",
"default": 5,
"description": "Maximum proposals to generate.",
},
},
},
},
]
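# Illustrative client call (paths are placeholders): analyze every Terraform file under
# terraform/ at warning severity, keeping at most three insights per file:
#   {"method": "tools/call",
#    "params": {"name": "waf_analyze",
#               "arguments": {"files": ["terraform/*.tf"], "limit": 3, "min_severity": "warning"}}}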
class WafIntelligenceTools:
def __init__(self) -> None:
self.workspace_root = _cloudflare_root()
self.repo_root = self.workspace_root.parent
self.waf = WAFIntelligence(workspace_path=str(self.workspace_root))
def _resolve_path(self, raw: str) -> Path:
path = Path(raw)
if path.is_absolute():
return path
candidates = [
Path.cwd() / path,
self.workspace_root / path,
self.repo_root / path,
]
for candidate in candidates:
if candidate.exists():
return candidate
return self.workspace_root / path
def waf_capabilities(self) -> Dict[str, Any]:
return {
"ok": True,
"summary": "WAF Intelligence capabilities.",
"data": {"capabilities": self.waf.capabilities},
"truncated": False,
"next_steps": [
"Call waf_analyze(file=..., limit=...) to analyze config.",
"Call waf_assess(include_threat_intel=true) for a broader assessment.",
],
}
def waf_analyze(
self,
*,
file: Optional[str] = None,
files: Optional[List[str]] = None,
limit: int = 3,
min_severity: str = "warning",
) -> Dict[str, Any]:
paths: List[str] = []
if files:
for pattern in files:
paths.extend(glob.glob(pattern))
if file:
paths.append(file)
seen = set()
unique_paths: List[str] = []
for p in paths:
if p not in seen:
seen.add(p)
unique_paths.append(p)
if not unique_paths:
return {
"ok": False,
"summary": "Provide 'file' or 'files' to analyze.",
"truncated": False,
"next_steps": ["Call waf_analyze(file='terraform/waf.tf')"],
}
results: List[Dict[str, Any]] = []
for p in unique_paths:
path = self._resolve_path(p)
if not path.exists():
results.append(
{
"file": str(path),
"ok": False,
"summary": "File not found.",
}
)
continue
insights = self.waf.analyze_and_recommend(
str(path),
limit=limit,
min_severity=min_severity,
)
results.append(
{
"file": str(path),
"ok": True,
"insights": [_insight_to_dict(i) for i in insights],
}
)
ok = all(r.get("ok") for r in results)
return {
"ok": ok,
"summary": f"Analyzed {len(results)} file(s).",
"data": {"results": results},
"truncated": False,
"next_steps": [
"Raise/lower min_severity or limit to tune output size.",
],
}
def waf_assess(
self,
*,
waf_config_path: Optional[str] = None,
include_threat_intel: bool = False,
) -> Dict[str, Any]:
waf_config_path_resolved = (
str(self._resolve_path(waf_config_path)) if waf_config_path else None
)
assessment = self.waf.full_assessment(
waf_config_path=waf_config_path_resolved,
include_threat_intel=include_threat_intel,
)
return {
"ok": True,
"summary": "WAF assessment complete.",
"data": _assessment_to_dict(assessment),
"truncated": False,
"next_steps": [
"Call waf_generate_gitops_proposals(...) to draft Terraform rule proposals (best-effort).",
],
}
def waf_generate_gitops_proposals(
self,
*,
waf_config_path: Optional[str] = None,
include_threat_intel: bool = True,
max_proposals: int = 5,
) -> Dict[str, Any]:
waf_config_path_resolved = (
str(self._resolve_path(waf_config_path)) if waf_config_path else None
)
assessment = self.waf.full_assessment(
waf_config_path=waf_config_path_resolved,
include_threat_intel=include_threat_intel,
)
proposals = self.waf.generate_gitops_proposals(
threat_report=assessment.threat_report,
max_proposals=max_proposals,
)
return {
"ok": True,
"summary": f"Generated {len(proposals)} proposal(s).",
"data": {
"assessment": _assessment_to_dict(assessment),
"proposals": proposals,
},
"truncated": False,
"next_steps": [
"If proposals are empty, enable threat intel and ensure required credentials/log sources exist.",
],
}
class StdioJsonRpc:
def __init__(self) -> None:
self._in = sys.stdin.buffer
self._out = sys.stdout.buffer
self._mode: str | None = None # "headers" | "line"
def read_message(self) -> Optional[Dict[str, Any]]:
while True:
if self._mode == "line":
line = self._in.readline()
if not line:
return None
raw = line.decode("utf-8", "replace").strip()
if not raw:
continue
try:
msg = json.loads(raw)
except Exception:
continue
if isinstance(msg, dict):
return msg
continue
first = self._in.readline()
if not first:
return None
if first in (b"\r\n", b"\n"):
continue
# Auto-detect newline-delimited JSON framing.
if self._mode is None and first.lstrip().startswith(b"{"):
try:
msg = json.loads(first.decode("utf-8", "replace"))
except Exception:
msg = None
if isinstance(msg, dict):
self._mode = "line"
return msg
headers: Dict[str, str] = {}
try:
text = first.decode("utf-8", "replace").strip()
except Exception:
continue
if ":" not in text:
continue
k, v = text.split(":", 1)
headers[k.lower().strip()] = v.strip()
while True:
line = self._in.readline()
if not line:
return None
if line in (b"\r\n", b"\n"):
break
try:
text = line.decode("utf-8", "replace").strip()
except Exception:
continue
if ":" not in text:
continue
k, v = text.split(":", 1)
headers[k.lower().strip()] = v.strip()
if "content-length" not in headers:
return None
try:
length = int(headers["content-length"])
except ValueError:
return None
body = self._in.read(length)
if not body:
return None
self._mode = "headers"
msg = json.loads(body.decode("utf-8", "replace"))
if isinstance(msg, dict):
return msg
return None
def write_message(self, message: Dict[str, Any]) -> None:
if self._mode == "line":
payload = json.dumps(
message, ensure_ascii=False, separators=(",", ":"), default=str
).encode("utf-8")
self._out.write(payload + b"\n")
self._out.flush()
return
body = json.dumps(
message, ensure_ascii=False, separators=(",", ":"), default=str
).encode("utf-8")
header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
self._out.write(header)
self._out.write(body)
self._out.flush()
def main() -> None:
tools = WafIntelligenceTools()
rpc = StdioJsonRpc()
handlers: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] = {
"waf_capabilities": lambda a: tools.waf_capabilities(),
"analyze_waf": lambda a: tools.waf_analyze(
file=a.get("file"),
files=a.get("files"),
limit=int(a.get("limit", 3)),
min_severity=str(a.get("severity_threshold", "warning")),
),
"waf_analyze": lambda a: tools.waf_analyze(**a),
"waf_assess": lambda a: tools.waf_assess(**a),
"waf_generate_gitops_proposals": lambda a: tools.waf_generate_gitops_proposals(
**a
),
}
while True:
msg = rpc.read_message()
if msg is None:
return
method = msg.get("method")
msg_id = msg.get("id")
params = msg.get("params") or {}
try:
if method == "initialize":
result = {
"protocolVersion": "2024-11-05",
"serverInfo": {"name": "waf_intelligence", "version": "0.1.0"},
"capabilities": {"tools": {}},
}
rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result})
continue
if method == "tools/list":
rpc.write_message(
{"jsonrpc": "2.0", "id": msg_id, "result": {"tools": TOOLS}}
)
continue
if method == "tools/call":
tool_name = str(params.get("name") or "")
args = params.get("arguments") or {}
routing_action, shadow = layer0_entry(
_shadow_query_repr(tool_name, args)
)
if routing_action != "HANDOFF_TO_LAYER1":
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
_layer0_payload(routing_action, shadow), is_error=True
),
}
)
continue
handler = handlers.get(tool_name)
if not handler:
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
{
"ok": False,
"summary": f"Unknown tool: {tool_name}",
"data": {"known_tools": sorted(handlers.keys())},
"truncated": False,
"next_steps": ["Call tools/list"],
},
is_error=True,
),
}
)
continue
payload = handler(args)
is_error = (
not bool(payload.get("ok", True))
if isinstance(payload, dict)
else False
)
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(payload, is_error=is_error),
}
)
continue
# Ignore notifications.
if msg_id is None:
continue
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
{"ok": False, "summary": f"Unsupported method: {method}"},
is_error=True,
),
}
)
except Exception as e: # noqa: BLE001
if msg_id is not None:
rpc.write_message(
{
"jsonrpc": "2.0",
"id": msg_id,
"result": _mcp_text_result(
{"ok": False, "summary": f"fatal error: {e}"},
is_error=True,
),
}
)
def _shadow_query_repr(tool_name: str, tool_args: Dict[str, Any]) -> str:
if tool_name == "waf_capabilities":
return "List WAF Intelligence capabilities."
try:
return f"{tool_name}: {json.dumps(tool_args, sort_keys=True, default=str)}"
except Exception:
return f"{tool_name}: {str(tool_args)}"
def _layer0_payload(routing_action: str, shadow: ShadowEvalResult) -> Dict[str, Any]:
if routing_action == "FAIL_CLOSED":
return {"ok": False, "summary": "Layer 0: cannot comply with this request."}
if routing_action == "HANDOFF_TO_GUARDRAILS":
reason = shadow.reason or "governance_violation"
return {
"ok": False,
"summary": f"Layer 0: governance violation detected ({reason}).",
}
if routing_action == "PROMPT_FOR_CLARIFICATION":
return {
"ok": False,
"summary": "Layer 0: request is ambiguous. Please clarify and retry.",
}
return {"ok": False, "summary": "Layer 0: unrecognized routing action; refusing."}
if __name__ == "__main__":
main()

View File

@@ -6,27 +6,26 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from mcp.waf_intelligence.analyzer import AnalysisResult, RuleViolation, WAFRuleAnalyzer
from mcp.waf_intelligence.compliance import ComplianceMapper, FrameworkMapping
from mcp.waf_intelligence.generator import GeneratedRule, WAFRuleGenerator
from .analyzer import AnalysisResult, RuleViolation, WAFRuleAnalyzer
from .compliance import ComplianceMapper, FrameworkMapping
from .generator import GeneratedRule, WAFRuleGenerator
# Optional advanced modules (Phase 7)
try:
from mcp.waf_intelligence.threat_intel import (
from .threat_intel import (
ThreatIntelCollector,
ThreatIntelReport,
ThreatIndicator,
)
_HAS_THREAT_INTEL = True
except ImportError:
_HAS_THREAT_INTEL = False
ThreatIntelCollector = None
try:
from mcp.waf_intelligence.classifier import (
ThreatClassifier,
ClassificationResult,
)
from .classifier import ThreatClassifier
_HAS_CLASSIFIER = True
except ImportError:
_HAS_CLASSIFIER = False
@@ -45,14 +44,14 @@ class WAFInsight:
@dataclass
class ThreatAssessment:
"""Phase 7: Comprehensive threat assessment result."""
analysis_result: Optional[AnalysisResult] = None
threat_report: Optional[Any] = None # ThreatIntelReport when available
classification_summary: Dict[str, int] = field(default_factory=dict)
risk_score: float = 0.0
recommended_actions: List[str] = field(default_factory=list)
generated_at: datetime = field(default_factory=datetime.utcnow)
@property
def risk_level(self) -> str:
if self.risk_score >= 0.8:
@@ -81,22 +80,22 @@ class WAFIntelligence:
enable_ml_classifier: bool = True,
) -> None:
self.workspace = Path(workspace_path) if workspace_path else Path.cwd()
# Core components
self.analyzer = WAFRuleAnalyzer()
self.generator = WAFRuleGenerator()
self.mapper = ComplianceMapper()
# Phase 7 components (optional)
self.threat_intel: Optional[Any] = None
self.classifier: Optional[Any] = None
if enable_threat_intel and _HAS_THREAT_INTEL:
try:
self.threat_intel = ThreatIntelCollector()
except Exception:
pass
if enable_ml_classifier and _HAS_CLASSIFIER:
try:
self.classifier = ThreatClassifier()
@@ -149,24 +148,24 @@ class WAFIntelligence:
) -> Optional[Any]:
"""
Collect threat intelligence from logs and external feeds.
Args:
log_paths: Paths to Cloudflare log files
max_indicators: Maximum indicators to collect
Returns:
ThreatIntelReport or None if unavailable
"""
if not self.threat_intel:
return None
# Default log paths
if log_paths is None:
log_paths = [
str(self.workspace / "logs"),
"/var/log/cloudflare",
]
return self.threat_intel.collect(
log_paths=log_paths,
max_indicators=max_indicators,
@@ -175,16 +174,16 @@ class WAFIntelligence:
def classify_threat(self, payload: str) -> Optional[Any]:
"""
Classify a payload using ML classifier.
Args:
payload: Request payload to classify
Returns:
ClassificationResult or None
"""
if not self.classifier:
return None
return self.classifier.classify(payload)
def full_assessment(
@@ -195,51 +194,52 @@ class WAFIntelligence:
) -> ThreatAssessment:
"""
Phase 7: Perform comprehensive threat assessment.
Combines:
- WAF configuration analysis
- Threat intelligence collection
- ML classification summary
- Risk scoring
Args:
waf_config_path: Path to WAF Terraform file
log_paths: Paths to log files
include_threat_intel: Whether to collect threat intel
Returns:
ThreatAssessment with full analysis results
"""
assessment = ThreatAssessment()
risk_factors: List[float] = []
recommendations: List[str] = []
# 1. Analyze WAF configuration
if waf_config_path is None:
waf_config_path = str(self.workspace / "terraform" / "waf.tf")
if Path(waf_config_path).exists():
assessment.analysis_result = self.analyzer.analyze_file(
waf_config_path,
min_severity="info",
)
# Calculate risk from violations
severity_weights = {"error": 0.8, "warning": 0.5, "info": 0.2}
for violation in assessment.analysis_result.violations:
weight = severity_weights.get(violation.severity, 0.3)
risk_factors.append(weight)
# Generate recommendations
critical_count = sum(
1 for v in assessment.analysis_result.violations
1
for v in assessment.analysis_result.violations
if v.severity == "error"
)
if critical_count > 0:
recommendations.append(
f"🔴 Fix {critical_count} critical WAF configuration issues"
)
# 2. Collect threat intelligence
if include_threat_intel and self.threat_intel:
try:
@@ -247,52 +247,55 @@ class WAFIntelligence:
log_paths=log_paths,
max_indicators=50,
)
if assessment.threat_report:
indicators = assessment.threat_report.indicators
# Count by severity
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
for ind in indicators:
sev = getattr(ind, "severity", "low")
severity_counts[sev] = severity_counts.get(sev, 0) + 1
# Add to classification summary
assessment.classification_summary["threat_indicators"] = len(indicators)
assessment.classification_summary["threat_indicators"] = len(
indicators
)
assessment.classification_summary.update(severity_counts)
# Calculate threat intel risk
if indicators:
critical_ratio = severity_counts["critical"] / len(indicators)
high_ratio = severity_counts["high"] / len(indicators)
risk_factors.append(critical_ratio * 0.9 + high_ratio * 0.7)
if severity_counts["critical"] > 0:
recommendations.append(
f"🚨 Block {severity_counts['critical']} critical threat IPs immediately"
)
except Exception:
pass
# 3. ML classification summary (from any collected data)
if self.classifier and assessment.threat_report:
try:
attack_types = {"sqli": 0, "xss": 0, "rce": 0, "clean": 0, "unknown": 0}
indicators = assessment.threat_report.indicators
pattern_indicators = [
i for i in indicators
i
for i in indicators
if getattr(i, "indicator_type", "") == "pattern"
]
for ind in pattern_indicators[:20]: # Sample first 20
result = self.classifier.classify(ind.value)
if result:
label = result.label
attack_types[label] = attack_types.get(label, 0) + 1
assessment.classification_summary["ml_classifications"] = attack_types
# Add ML risk factor
dangerous = attack_types.get("sqli", 0) + attack_types.get("rce", 0)
if dangerous > 5:
@@ -302,15 +305,17 @@ class WAFIntelligence:
)
except Exception:
pass
# 4. Calculate final risk score
if risk_factors:
assessment.risk_score = min(1.0, sum(risk_factors) / max(len(risk_factors), 1))
assessment.risk_score = min(
1.0, sum(risk_factors) / max(len(risk_factors), 1)
)
else:
assessment.risk_score = 0.3 # Baseline risk
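# Worked example (descriptive only): violation weights [0.8, 0.5] plus a threat-intel
# factor of 0.45 give min(1.0, (0.8 + 0.5 + 0.45) / 3) ~= 0.58.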
assessment.recommended_actions = recommendations
return assessment
def generate_gitops_proposals(
@@ -320,42 +325,44 @@ class WAFIntelligence:
) -> List[Dict[str, Any]]:
"""
Generate GitOps-ready rule proposals.
Args:
threat_report: ThreatIntelReport to use
max_proposals: Maximum proposals to generate
Returns:
List of proposal dicts ready for MR creation
"""
proposals: List[Dict[str, Any]] = []
if not threat_report:
return proposals
try:
# Import proposer dynamically
from gitops.waf_rule_proposer import WAFRuleProposer
proposer = WAFRuleProposer(workspace_path=str(self.workspace))
batch = proposer.generate_proposals(
threat_report=threat_report,
max_proposals=max_proposals,
)
for proposal in batch.proposals:
proposals.append({
"name": proposal.rule_name,
"type": proposal.rule_type,
"severity": proposal.severity,
"confidence": proposal.confidence,
"terraform": proposal.terraform_code,
"justification": proposal.justification,
"auto_deploy": proposal.auto_deploy_eligible,
})
proposals.append(
{
"name": proposal.rule_name,
"type": proposal.rule_type,
"severity": proposal.severity,
"confidence": proposal.confidence,
"terraform": proposal.terraform_code,
"justification": proposal.justification,
"auto_deploy": proposal.auto_deploy_eligible,
}
)
except ImportError:
pass
return proposals
@property

326
mcp/waf_intelligence/server.py Executable file → Normal file
View File

@@ -1,326 +1,14 @@
#!/usr/bin/env python3
"""
WAF Intelligence MCP Server for VS Code Copilot.
from __future__ import annotations
This implements the Model Context Protocol (MCP) stdio interface
so VS Code can communicate with your WAF Intelligence system.
"""
Deprecated entrypoint kept for older editor configs.
Use `python3 -m mcp.waf_intelligence.mcp_server` (or `waf_intel_mcp.py`) instead.
"""
import json
import sys
from typing import Any
# Add parent to path for imports
sys.path.insert(0, '/Users/sovereign/Desktop/CLOUDFLARE')
from mcp.waf_intelligence.orchestrator import WAFIntelligence
from mcp.waf_intelligence.analyzer import WAFRuleAnalyzer
from layer0 import layer0_entry
from layer0.shadow_classifier import ShadowEvalResult
class WAFIntelligenceMCPServer:
"""MCP Server wrapper for WAF Intelligence."""
def __init__(self):
self.waf = WAFIntelligence()
self.analyzer = WAFRuleAnalyzer()
def get_capabilities(self) -> dict:
"""Return server capabilities."""
return {
"tools": [
{
"name": "waf_analyze",
"description": "Analyze WAF logs and detect attack patterns",
"inputSchema": {
"type": "object",
"properties": {
"log_file": {
"type": "string",
"description": "Path to WAF log file (optional)"
},
"zone_id": {
"type": "string",
"description": "Cloudflare zone ID (optional)"
}
}
}
},
{
"name": "waf_assess",
"description": "Run full security assessment with threat intel and ML classification",
"inputSchema": {
"type": "object",
"properties": {
"zone_id": {
"type": "string",
"description": "Cloudflare zone ID"
}
},
"required": ["zone_id"]
}
},
{
"name": "waf_generate_rules",
"description": "Generate Terraform WAF rules from threat intelligence",
"inputSchema": {
"type": "object",
"properties": {
"zone_id": {
"type": "string",
"description": "Cloudflare zone ID"
},
"min_confidence": {
"type": "number",
"description": "Minimum confidence threshold (0-1)",
"default": 0.7
}
},
"required": ["zone_id"]
}
},
{
"name": "waf_capabilities",
"description": "List available WAF Intelligence capabilities",
"inputSchema": {
"type": "object",
"properties": {}
}
}
]
}
def handle_tool_call(self, name: str, arguments: dict) -> dict:
"""Handle a tool invocation."""
try:
if name == "waf_capabilities":
return {
"content": [
{
"type": "text",
"text": json.dumps({
"capabilities": self.waf.capabilities,
"status": "operational"
}, indent=2)
}
]
}
elif name == "waf_analyze":
log_file = arguments.get("log_file")
zone_id = arguments.get("zone_id")
if log_file:
result = self.analyzer.analyze_log_file(log_file)
else:
result = {
"message": "No log file provided. Use zone_id for live analysis.",
"capabilities": self.waf.capabilities
}
return {
"content": [
{"type": "text", "text": json.dumps(result, indent=2, default=str)}
]
}
elif name == "waf_assess":
zone_id = arguments.get("zone_id")
# full_assessment uses workspace paths, not zone_id
assessment = self.waf.full_assessment(
include_threat_intel=True
)
# Build result from ThreatAssessment dataclass
result = {
"zone_id": zone_id,
"risk_score": assessment.risk_score,
"risk_level": assessment.risk_level,
"classification_summary": assessment.classification_summary,
"recommended_actions": assessment.recommended_actions[:10], # Top 10
"has_analysis": assessment.analysis_result is not None,
"has_threat_intel": assessment.threat_report is not None,
"generated_at": str(assessment.generated_at)
}
return {
"content": [
{"type": "text", "text": json.dumps(result, indent=2, default=str)}
]
}
elif name == "waf_generate_rules":
zone_id = arguments.get("zone_id")
min_confidence = arguments.get("min_confidence", 0.7)
# Generate proposals (doesn't use zone_id directly)
proposals = self.waf.generate_gitops_proposals(
max_proposals=5
)
result = {
"zone_id": zone_id,
"min_confidence": min_confidence,
"proposals_count": len(proposals),
"proposals": proposals
}
return {
"content": [
{"type": "text", "text": json.dumps(result, indent=2, default=str) if proposals else "No rules generated (no threat data available)"}
]
}
else:
return {
"content": [
{"type": "text", "text": f"Unknown tool: {name}"}
],
"isError": True
}
except Exception as e:
return {
"content": [
{"type": "text", "text": f"Error: {str(e)}"}
],
"isError": True
}
def run(self):
"""Run the MCP server (stdio mode)."""
# Send server info
server_info = {
"jsonrpc": "2.0",
"method": "initialized",
"params": {
"serverInfo": {
"name": "waf-intelligence",
"version": "1.0.0"
},
"capabilities": self.get_capabilities()
}
}
# Main loop - read JSON-RPC messages from stdin
for line in sys.stdin:
try:
message = json.loads(line.strip())
if message.get("method") == "initialize":
response = {
"jsonrpc": "2.0",
"id": message.get("id"),
"result": {
"protocolVersion": "2024-11-05",
"serverInfo": {
"name": "waf-intelligence",
"version": "1.0.0"
},
"capabilities": {
"tools": {}
}
}
}
print(json.dumps(response), flush=True)
elif message.get("method") == "tools/list":
response = {
"jsonrpc": "2.0",
"id": message.get("id"),
"result": self.get_capabilities()
}
print(json.dumps(response), flush=True)
elif message.get("method") == "tools/call":
params = message.get("params", {})
tool_name = params.get("name")
tool_args = params.get("arguments", {})
# Layer 0: pre-boot Shadow Eval gate before handling tool calls.
routing_action, shadow = layer0_entry(_shadow_query_repr(tool_name, tool_args))
if routing_action != "HANDOFF_TO_LAYER1":
response = _layer0_mcp_response(routing_action, shadow, message.get("id"))
print(json.dumps(response), flush=True)
continue
result = self.handle_tool_call(tool_name, tool_args)
response = {
"jsonrpc": "2.0",
"id": message.get("id"),
"result": result
}
print(json.dumps(response), flush=True)
elif message.get("method") == "notifications/initialized":
# Client acknowledged initialization
pass
else:
# Unknown method
response = {
"jsonrpc": "2.0",
"id": message.get("id"),
"error": {
"code": -32601,
"message": f"Method not found: {message.get('method')}"
}
}
print(json.dumps(response), flush=True)
except json.JSONDecodeError:
continue
except Exception as e:
error_response = {
"jsonrpc": "2.0",
"id": None,
"error": {
"code": -32603,
"message": str(e)
}
}
print(json.dumps(error_response), flush=True)
from .mcp_server import main
if __name__ == "__main__":
server = WAFIntelligenceMCPServer()
server.run()
main()
def _shadow_query_repr(tool_name: str, tool_args: dict) -> str:
"""Build a textual representation of the tool call for Layer 0 classification."""
try:
return f"{tool_name}: {json.dumps(tool_args, sort_keys=True)}"
except TypeError:
return f"{tool_name}: {str(tool_args)}"
def _layer0_mcp_response(routing_action: str, shadow: ShadowEvalResult, msg_id: Any) -> dict:
"""
Map Layer 0 outcomes to MCP responses.
Catastrophic/forbidden/ambiguous short-circuit with minimal disclosure.
"""
base = {"jsonrpc": "2.0", "id": msg_id}
if routing_action == "FAIL_CLOSED":
base["error"] = {"code": -32000, "message": "Layer 0: cannot comply with this request."}
return base
if routing_action == "HANDOFF_TO_GUARDRAILS":
reason = shadow.reason or "governance_violation"
base["error"] = {
"code": -32001,
"message": f"Layer 0: governance violation detected ({reason}).",
}
return base
if routing_action == "PROMPT_FOR_CLARIFICATION":
base["error"] = {
"code": -32002,
"message": "Layer 0: request is ambiguous. Please clarify and retry.",
}
return base
base["error"] = {"code": -32099, "message": "Layer 0: unrecognized routing action; refusing."}
return base