chore: pre-migration snapshot
Layer0, MCP servers, Terraform consolidation
@@ -3,4 +3,6 @@ MCP tools for the CLOUDFLARE workspace.

Currently:
- oracle_answer: compliance / security oracle
- cloudflare_safe: summary-first Cloudflare state + tunnel helpers
- akash_docs: Akash docs fetch/search + SDL template helper
"""

10 mcp/akash_docs/__init__.py Normal file
@@ -0,0 +1,10 @@
"""
Akash docs + deployment helpers exposed as an MCP server.

Tools:
- akash_docs_list_routes: discover common docs routes from akash.network
- akash_docs_fetch: fetch a docs page (prefers GitHub markdown, falls back to site HTML)
- akash_docs_search: keyword search across discovered routes (cached)
- akash_sdl_snippet: generate a minimal Akash SDL template
"""
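A minimal sketch of exercising these tools over the server's stdio JSON-RPC loop, using the newline-delimited framing that StdioJsonRpc (further down in this commit) auto-detects. The invocation and tool arguments are illustrative and assume the directory containing the mcp package is the working directory:

import json
import subprocess

# Spawn the akash_docs server added in this commit (assumes `mcp` is importable from cwd).
proc = subprocess.Popen(
    ["python", "-m", "mcp.akash_docs"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
)

def rpc(msg):
    # Newline-delimited JSON framing; the server auto-detects this mode.
    proc.stdin.write(json.dumps(msg) + "\n")
    proc.stdin.flush()
    return json.loads(proc.stdout.readline())

rpc({"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {}})
print(rpc({
    "jsonrpc": "2.0",
    "id": 2,
    "method": "tools/call",
    "params": {
        "name": "akash_sdl_snippet",
        # Illustrative arguments only.
        "arguments": {"service_name": "web", "container_image": "nginx:1.27", "port": 80},
    },
}))
proc.stdin.close()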
7 mcp/akash_docs/__main__.py Normal file
@@ -0,0 +1,7 @@
from __future__ import annotations

from .server import main

if __name__ == "__main__":
    main()
861 mcp/akash_docs/server.py Normal file
@@ -0,0 +1,861 @@
from __future__ import annotations

import hashlib
import json
import os
import re
import sys
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple


AKASH_SITE_BASE = "https://akash.network"
AKASH_DOCS_BASE = f"{AKASH_SITE_BASE}/docs"

AKASH_DOCS_GITHUB_OWNER = "akash-network"
AKASH_DOCS_GITHUB_REPO = "website-revamp"
AKASH_DOCS_GITHUB_REF_DEFAULT = "main"
AKASH_DOCS_GITHUB_DOCS_ROOT = "src/content/Docs"

MAX_BYTES_DEFAULT = 32_000


def _repo_root() -> Path:
    # server.py -> akash_docs -> mcp -> cloudflare -> <repo root>
    return Path(__file__).resolve().parents[3]


def _utc_now_iso() -> str:
    return datetime.now(timezone.utc).isoformat()


def _max_bytes() -> int:
    raw = (os.getenv("VM_MCP_MAX_BYTES") or "").strip()
    if not raw:
        return MAX_BYTES_DEFAULT
    try:
        return max(4_096, int(raw))
    except ValueError:
        return MAX_BYTES_DEFAULT


def _sha256_hex(text: str) -> str:
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def _http_get(url: str, *, timeout: int = 30) -> str:
    req = urllib.request.Request(
        url=url,
        headers={
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "User-Agent": "work-core-akamcp/0.1 (+https://akash.network)",
        },
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return resp.read().decode("utf-8", "replace")

def _normalize_route(route_or_url: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Returns (route, canonical_url).
|
||||
|
||||
route: "getting-started/what-is-akash" (no leading/trailing slashes)
|
||||
canonical_url: https://akash.network/docs/<route>
|
||||
"""
|
||||
raw = (route_or_url or "").strip()
|
||||
if not raw:
|
||||
return "", AKASH_DOCS_BASE + "/"
|
||||
|
||||
if raw.startswith("http://") or raw.startswith("https://"):
|
||||
parsed = urllib.parse.urlparse(raw)
|
||||
path = parsed.path or ""
|
||||
# Normalize to docs route if possible.
|
||||
if path in ("/docs", "/docs/"):
|
||||
return "", AKASH_DOCS_BASE + "/"
|
||||
if path.startswith("/docs/"):
|
||||
route = path[len("/docs/") :].strip("/")
|
||||
return route, f"{AKASH_DOCS_BASE}/{route}"
|
||||
return path.strip("/"), raw
|
||||
|
||||
# Accept "/docs/..." or "docs/..."
|
||||
route = raw.lstrip("/")
|
||||
if route in ("docs", "docs/"):
|
||||
return "", AKASH_DOCS_BASE + "/"
|
||||
if route.startswith("docs/"):
|
||||
route = route[len("docs/") :]
|
||||
route = route.strip("/")
|
||||
return route, f"{AKASH_DOCS_BASE}/{route}" if route else AKASH_DOCS_BASE + "/"
|
||||
|
||||
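# Illustrative behaviour of _normalize_route (assumed example routes):
#   _normalize_route("https://akash.network/docs/getting-started/what-is-akash/")
#     -> ("getting-started/what-is-akash",
#         "https://akash.network/docs/getting-started/what-is-akash")
#   _normalize_route("docs/providers")
#     -> ("providers", "https://akash.network/docs/providers")
#   _normalize_route("")
#     -> ("", "https://akash.network/docs/")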
|
||||
def _strip_frontmatter(markdown: str) -> str:
|
||||
# Remove leading YAML frontmatter: ---\n...\n---\n
|
||||
if not markdown.startswith("---"):
|
||||
return markdown
|
||||
m = re.match(r"^---\s*\n.*?\n---\s*\n", markdown, flags=re.S)
|
||||
if not m:
|
||||
return markdown
|
||||
return markdown[m.end() :]
|
||||
|
||||
|
||||
def _github_candidates(route: str) -> List[str]:
|
||||
base = f"{AKASH_DOCS_GITHUB_DOCS_ROOT}/{route}".rstrip("/")
|
||||
candidates = [
|
||||
f"{base}/index.md",
|
||||
f"{base}/index.mdx",
|
||||
f"{base}.md",
|
||||
f"{base}.mdx",
|
||||
]
|
||||
# Handle root docs landing if route is empty.
|
||||
if not route:
|
||||
candidates = [
|
||||
f"{AKASH_DOCS_GITHUB_DOCS_ROOT}/index.md",
|
||||
f"{AKASH_DOCS_GITHUB_DOCS_ROOT}/index.mdx",
|
||||
]
|
||||
return candidates
|
||||
|
||||
|
||||
def _fetch_markdown_from_github(route: str, *, ref: str) -> Tuple[str, str, str]:
|
||||
"""
|
||||
Returns (markdown, raw_url, repo_path) or raises urllib.error.HTTPError.
|
||||
"""
|
||||
last_err: Optional[urllib.error.HTTPError] = None
|
||||
for repo_path in _github_candidates(route):
|
||||
raw_url = (
|
||||
f"https://raw.githubusercontent.com/{AKASH_DOCS_GITHUB_OWNER}/"
|
||||
f"{AKASH_DOCS_GITHUB_REPO}/{ref}/{repo_path}"
|
||||
)
|
||||
try:
|
||||
return _http_get(raw_url), raw_url, repo_path
|
||||
except urllib.error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
last_err = e
|
||||
continue
|
||||
raise
|
||||
if last_err:
|
||||
raise last_err
|
||||
raise urllib.error.HTTPError(
|
||||
url="",
|
||||
code=404,
|
||||
msg="Not Found",
|
||||
hdrs=None,
|
||||
fp=None,
|
||||
)
|
||||
|
||||
|
||||
def _extract_article_html(page_html: str) -> str:
|
||||
m = re.search(r"<article\b[^>]*>(.*?)</article>", page_html, flags=re.S | re.I)
|
||||
if m:
|
||||
return m.group(1)
|
||||
m = re.search(r"<main\b[^>]*>(.*?)</main>", page_html, flags=re.S | re.I)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return page_html
|
||||
|
||||
|
||||
def _html_to_text(article_html: str) -> str:
|
||||
# Drop scripts/styles
|
||||
cleaned = re.sub(
|
||||
r"<(script|style)\b[^>]*>.*?</\1>", "", article_html, flags=re.S | re.I
|
||||
)
|
||||
|
||||
# Preserve code blocks a bit better (Astro uses <div class="ec-line"> for each line)
|
||||
def _pre_repl(match: re.Match[str]) -> str:
|
||||
pre = match.group(0)
|
||||
pre = re.sub(r"</div>\s*", "\n", pre, flags=re.I)
|
||||
pre = re.sub(r"<div\b[^>]*>", "", pre, flags=re.I)
|
||||
pre = re.sub(r"<br\s*/?>", "\n", pre, flags=re.I)
|
||||
pre = re.sub(r"<[^>]+>", "", pre)
|
||||
return "\n```\n" + _html_unescape(pre).strip() + "\n```\n"
|
||||
|
||||
cleaned = re.sub(r"<pre\b[^>]*>.*?</pre>", _pre_repl, cleaned, flags=re.S | re.I)
|
||||
|
||||
# Newlines for common block tags
|
||||
cleaned = re.sub(
|
||||
r"</(p|h1|h2|h3|h4|h5|h6|li|blockquote)>", "\n", cleaned, flags=re.I
|
||||
)
|
||||
cleaned = re.sub(r"<br\s*/?>", "\n", cleaned, flags=re.I)
|
||||
cleaned = re.sub(r"<hr\b[^>]*>", "\n---\n", cleaned, flags=re.I)
|
||||
|
||||
# Strip remaining tags
|
||||
cleaned = re.sub(r"<[^>]+>", "", cleaned)
|
||||
|
||||
text = _html_unescape(cleaned)
|
||||
lines = [ln.rstrip() for ln in text.splitlines()]
|
||||
# Collapse excessive blank lines
|
||||
out: List[str] = []
|
||||
blank = False
|
||||
for ln in lines:
|
||||
if ln.strip() == "":
|
||||
if blank:
|
||||
continue
|
||||
blank = True
|
||||
out.append("")
|
||||
continue
|
||||
blank = False
|
||||
out.append(ln.strip())
|
||||
return "\n".join(out).strip()
|
||||
|
||||
|
||||
def _html_unescape(text: str) -> str:
|
||||
# Avoid importing html module repeatedly; do it lazily.
|
||||
import html as _html # local import to keep global import list small
|
||||
|
||||
return _html.unescape(text)
|
||||
|
||||
|
||||
def _discover_routes_from_docs_index() -> List[str]:
|
||||
html = _http_get(AKASH_DOCS_BASE + "/")
|
||||
hrefs = set(re.findall(r'href=\"(/docs/[^\"#?]+)\"', html))
|
||||
routes: List[str] = []
|
||||
for href in sorted(hrefs):
|
||||
route, _url = _normalize_route(href)
|
||||
if route:
|
||||
routes.append(route)
|
||||
return routes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CachedDoc:
|
||||
cache_key: str
|
||||
fetched_at: str
|
||||
source: str
|
||||
route: str
|
||||
url: str
|
||||
ref: str
|
||||
content_path: str
|
||||
|
||||
|
||||
class DocStore:
|
||||
def __init__(self, root_dir: Path) -> None:
|
||||
self.root_dir = root_dir
|
||||
self.pages_dir = root_dir / "pages"
|
||||
self.index_path = root_dir / "index.json"
|
||||
self.pages_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._index: Dict[str, Dict[str, Any]] = {}
|
||||
if self.index_path.exists():
|
||||
try:
|
||||
self._index = json.loads(self.index_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
self._index = {}
|
||||
|
||||
def _write_index(self) -> None:
|
||||
tmp = self.index_path.with_suffix(".tmp")
|
||||
tmp.write_text(
|
||||
json.dumps(self._index, ensure_ascii=False, indent=2) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
tmp.replace(self.index_path)
|
||||
|
||||
def get(self, cache_key: str) -> Optional[CachedDoc]:
|
||||
raw = self._index.get(cache_key)
|
||||
if not raw:
|
||||
return None
|
||||
path = Path(raw.get("content_path") or "")
|
||||
if not path.exists():
|
||||
return None
|
||||
return CachedDoc(
|
||||
cache_key=cache_key,
|
||||
fetched_at=str(raw.get("fetched_at") or ""),
|
||||
source=str(raw.get("source") or ""),
|
||||
route=str(raw.get("route") or ""),
|
||||
url=str(raw.get("url") or ""),
|
||||
ref=str(raw.get("ref") or ""),
|
||||
content_path=str(path),
|
||||
)
|
||||
|
||||
def save(
|
||||
self,
|
||||
*,
|
||||
cache_key: str,
|
||||
source: str,
|
||||
route: str,
|
||||
url: str,
|
||||
ref: str,
|
||||
content: str,
|
||||
) -> CachedDoc:
|
||||
content_hash = _sha256_hex(f"{source}:{ref}:{url}")[:20]
|
||||
path = self.pages_dir / f"{content_hash}.txt"
|
||||
path.write_text(content, encoding="utf-8")
|
||||
entry = {
|
||||
"fetched_at": _utc_now_iso(),
|
||||
"source": source,
|
||||
"route": route,
|
||||
"url": url,
|
||||
"ref": ref,
|
||||
"content_path": str(path),
|
||||
}
|
||||
self._index[cache_key] = entry
|
||||
self._write_index()
|
||||
return self.get(cache_key) or CachedDoc(
|
||||
cache_key=cache_key,
|
||||
fetched_at=entry["fetched_at"],
|
||||
source=source,
|
||||
route=route,
|
||||
url=url,
|
||||
ref=ref,
|
||||
content_path=str(path),
|
||||
)
|
||||
|
||||
|
||||
def _default_state_dir() -> Path:
|
||||
return _repo_root() / "archive_runtime" / "akash_docs_mcp"
|
||||
|
||||
|
||||
def _truncate_to_max_bytes(text: str, *, max_bytes: int) -> Tuple[str, bool]:
|
||||
blob = text.encode("utf-8")
|
||||
if len(blob) <= max_bytes:
|
||||
return text, False
|
||||
# Reserve a bit for the truncation notice
|
||||
reserve = min(512, max_bytes // 10)
|
||||
head = blob[: max(0, max_bytes - reserve)].decode("utf-8", "replace")
|
||||
head = head.rstrip() + "\n\n[TRUNCATED: response exceeded VM_MCP_MAX_BYTES]\n"
|
||||
return head, True
|
||||
|
||||
|
||||
def _mcp_text_result(text: str, *, is_error: bool = False) -> Dict[str, Any]:
|
||||
text, _truncated = _truncate_to_max_bytes(text, max_bytes=_max_bytes())
|
||||
result: Dict[str, Any] = {"content": [{"type": "text", "text": text}]}
|
||||
if is_error:
|
||||
result["isError"] = True
|
||||
return result
|
||||
|
||||
|
||||
class AkashDocsTools:
|
||||
def __init__(self) -> None:
|
||||
state_dir = Path(os.getenv("VM_AKASH_DOCS_MCP_STATE_DIR") or _default_state_dir())
|
||||
self.store = DocStore(state_dir)
|
||||
|
||||
def akash_docs_list_routes(self) -> Dict[str, Any]:
|
||||
routes = _discover_routes_from_docs_index()
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"Discovered {len(routes)} docs route(s) from {AKASH_DOCS_BASE}/.",
|
||||
"data": {"routes": routes},
|
||||
"next_steps": ["akash_docs_fetch(route_or_url=...)"],
|
||||
}
|
||||
|
||||
def akash_docs_fetch(
|
||||
self,
|
||||
*,
|
||||
route_or_url: str,
|
||||
source: str = "auto",
|
||||
ref: str = AKASH_DOCS_GITHUB_REF_DEFAULT,
|
||||
max_chars: int = 12_000,
|
||||
refresh: bool = False,
|
||||
strip_frontmatter: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
route, canonical_url = _normalize_route(route_or_url)
|
||||
source_norm = (source or "auto").strip().lower()
|
||||
if source_norm not in ("auto", "github", "site"):
|
||||
raise ValueError("source must be one of: auto, github, site")
|
||||
|
||||
max_chars_int = max(0, int(max_chars))
|
||||
# Avoid flooding clients; open content_path for full content.
|
||||
max_chars_int = min(max_chars_int, max(2_000, _max_bytes() - 8_000))
|
||||
|
||||
cache_key = f"{source_norm}:{ref}:{route or canonical_url}"
|
||||
cached = self.store.get(cache_key)
|
||||
if cached and not refresh:
|
||||
content = Path(cached.content_path).read_text(encoding="utf-8")
|
||||
if strip_frontmatter and cached.source == "github":
|
||||
content = _strip_frontmatter(content)
|
||||
truncated = len(content) > max_chars_int
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": "Returned cached docs content.",
|
||||
"data": {
|
||||
"source": cached.source,
|
||||
"route": cached.route,
|
||||
"url": cached.url,
|
||||
"ref": cached.ref,
|
||||
"cached": True,
|
||||
"fetched_at": cached.fetched_at,
|
||||
"content": content[:max_chars_int],
|
||||
"truncated": truncated,
|
||||
"content_path": cached.content_path,
|
||||
},
|
||||
"next_steps": ["Set refresh=true to refetch."],
|
||||
}
|
||||
|
||||
attempted: List[Dict[str, Any]] = []
|
||||
|
||||
def _try_github() -> Optional[Tuple[str, str, str]]:
|
||||
try:
|
||||
md, raw_url, repo_path = _fetch_markdown_from_github(route, ref=ref)
|
||||
return md, raw_url, repo_path
|
||||
except urllib.error.HTTPError as e:
|
||||
attempted.append({"source": "github", "status": getattr(e, "code", None), "detail": str(e)})
|
||||
return None
|
||||
|
||||
def _try_site() -> Optional[Tuple[str, str]]:
|
||||
try:
|
||||
html = _http_get(canonical_url)
|
||||
article = _extract_article_html(html)
|
||||
text = _html_to_text(article)
|
||||
return text, canonical_url
|
||||
except urllib.error.HTTPError as e:
|
||||
attempted.append({"source": "site", "status": getattr(e, "code", None), "detail": str(e)})
|
||||
return None
|
||||
|
||||
content: str
|
||||
final_source: str
|
||||
final_url: str
|
||||
extra: Dict[str, Any] = {}
|
||||
|
||||
if source_norm in ("auto", "github"):
|
||||
gh = _try_github()
|
||||
if gh:
|
||||
content, final_url, repo_path = gh
|
||||
final_source = "github"
|
||||
extra["repo_path"] = repo_path
|
||||
elif source_norm == "github":
|
||||
raise ValueError("GitHub fetch failed; try source='site' or verify the route/ref.")
|
||||
else:
|
||||
site = _try_site()
|
||||
if not site:
|
||||
raise ValueError(f"Fetch failed for route_or_url={route_or_url!r}. Attempts: {attempted}")
|
||||
content, final_url = site
|
||||
final_source = "site"
|
||||
else:
|
||||
site = _try_site()
|
||||
if not site:
|
||||
raise ValueError(f"Site fetch failed for route_or_url={route_or_url!r}. Attempts: {attempted}")
|
||||
content, final_url = site
|
||||
final_source = "site"
|
||||
|
||||
cached_doc = self.store.save(
|
||||
cache_key=cache_key,
|
||||
source=final_source,
|
||||
route=route,
|
||||
url=final_url,
|
||||
ref=ref,
|
||||
content=content,
|
||||
)
|
||||
|
||||
content_view = content
|
||||
if strip_frontmatter and final_source == "github":
|
||||
content_view = _strip_frontmatter(content_view)
|
||||
truncated = len(content_view) > max_chars_int
|
||||
content_out = content_view[:max_chars_int]
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"Fetched docs via {final_source}.",
|
||||
"data": {
|
||||
"source": final_source,
|
||||
"route": route,
|
||||
"url": final_url,
|
||||
"ref": ref,
|
||||
"cached": False,
|
||||
"fetched_at": cached_doc.fetched_at,
|
||||
"content": content_out,
|
||||
"truncated": truncated,
|
||||
"content_path": cached_doc.content_path,
|
||||
"attempts": attempted,
|
||||
**extra,
|
||||
},
|
||||
"next_steps": [
|
||||
"akash_docs_search(query=..., refresh=false)",
|
||||
],
|
||||
}
|
||||
|
||||
def akash_docs_search(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
refresh: bool = False,
|
||||
ref: str = AKASH_DOCS_GITHUB_REF_DEFAULT,
|
||||
) -> Dict[str, Any]:
|
||||
q = (query or "").strip()
|
||||
if not q:
|
||||
raise ValueError("query is required")
|
||||
limit = max(1, min(50, int(limit)))
|
||||
|
||||
routes = _discover_routes_from_docs_index()
|
||||
hits: List[Dict[str, Any]] = []
|
||||
|
||||
for route in routes:
|
||||
doc = self.akash_docs_fetch(
|
||||
route_or_url=route,
|
||||
source="github",
|
||||
ref=ref,
|
||||
max_chars=0, # search reads full content from content_path
|
||||
refresh=refresh,
|
||||
strip_frontmatter=True,
|
||||
)
|
||||
data = doc.get("data") or {}
|
||||
content_path = data.get("content_path")
|
||||
if not content_path:
|
||||
continue
|
||||
try:
|
||||
content = Path(str(content_path)).read_text(encoding="utf-8")
|
||||
content = _strip_frontmatter(content)
|
||||
except Exception:
|
||||
continue
|
||||
idx = content.lower().find(q.lower())
|
||||
if idx == -1:
|
||||
continue
|
||||
start = max(0, idx - 80)
|
||||
end = min(len(content), idx + 160)
|
||||
snippet = content[start:end].replace("\n", " ").strip()
|
||||
hits.append(
|
||||
{
|
||||
"route": route,
|
||||
"url": data.get("url"),
|
||||
"source": data.get("source"),
|
||||
"snippet": snippet,
|
||||
}
|
||||
)
|
||||
if len(hits) >= limit:
|
||||
break
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"Found {len(hits)} hit(s) across {len(routes)} route(s).",
|
||||
"data": {"query": q, "hits": hits, "routes_searched": len(routes)},
|
||||
"next_steps": ["akash_docs_fetch(route_or_url=hits[0].route)"],
|
||||
}
|
||||
|
||||
def akash_sdl_snippet(
|
||||
self,
|
||||
*,
|
||||
service_name: str,
|
||||
container_image: str,
|
||||
port: int,
|
||||
cpu_units: float = 0.5,
|
||||
memory_size: str = "512Mi",
|
||||
storage_size: str = "512Mi",
|
||||
denom: str = "uakt",
|
||||
price_amount: int = 100,
|
||||
) -> Dict[str, Any]:
|
||||
svc = (service_name or "").strip()
|
||||
img = (container_image or "").strip()
|
||||
if not svc:
|
||||
raise ValueError("service_name is required")
|
||||
if not img:
|
||||
raise ValueError("container_image is required")
|
||||
port_int = int(port)
|
||||
if port_int <= 0 or port_int > 65535:
|
||||
raise ValueError("port must be 1..65535")
|
||||
|
||||
sdl = f"""version: \"2.0\"
|
||||
|
||||
services:
|
||||
{svc}:
|
||||
image: {img}
|
||||
expose:
|
||||
- port: {port_int}
|
||||
to:
|
||||
- global: true
|
||||
|
||||
profiles:
|
||||
compute:
|
||||
{svc}:
|
||||
resources:
|
||||
cpu:
|
||||
units: {cpu_units}
|
||||
memory:
|
||||
size: {memory_size}
|
||||
storage:
|
||||
size: {storage_size}
|
||||
placement:
|
||||
akash:
|
||||
pricing:
|
||||
{svc}:
|
||||
denom: {denom}
|
||||
amount: {int(price_amount)}
|
||||
|
||||
deployment:
|
||||
{svc}:
|
||||
akash:
|
||||
profile: {svc}
|
||||
count: 1
|
||||
"""
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": "Generated an Akash SDL template.",
|
||||
"data": {
|
||||
"service_name": svc,
|
||||
"container_image": img,
|
||||
"port": port_int,
|
||||
"sdl": sdl,
|
||||
},
|
||||
"next_steps": [
|
||||
"Save as deploy.yaml and deploy via Akash Console or akash CLI.",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
TOOLS: List[Dict[str, Any]] = [
|
||||
{
|
||||
"name": "akash_docs_list_routes",
|
||||
"description": "Discover common Akash docs routes by scraping https://akash.network/docs/ (SSR HTML).",
|
||||
"inputSchema": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "akash_docs_fetch",
|
||||
"description": "Fetch an Akash docs page (prefers GitHub markdown in akash-network/website-revamp; falls back to site HTML).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"route_or_url": {"type": "string"},
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "auto|github|site",
|
||||
"default": "auto",
|
||||
},
|
||||
"ref": {"type": "string", "default": AKASH_DOCS_GITHUB_REF_DEFAULT},
|
||||
"max_chars": {"type": "integer", "default": 12000},
|
||||
"refresh": {"type": "boolean", "default": False},
|
||||
"strip_frontmatter": {"type": "boolean", "default": True},
|
||||
},
|
||||
"required": ["route_or_url"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "akash_docs_search",
|
||||
"description": "Keyword search across routes discovered from /docs (fetches + caches GitHub markdown).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"limit": {"type": "integer", "default": 10},
|
||||
"refresh": {"type": "boolean", "default": False},
|
||||
"ref": {"type": "string", "default": AKASH_DOCS_GITHUB_REF_DEFAULT},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "akash_sdl_snippet",
|
||||
"description": "Generate a minimal Akash SDL manifest for a single service exposing one port.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"service_name": {"type": "string"},
|
||||
"container_image": {"type": "string"},
|
||||
"port": {"type": "integer"},
|
||||
"cpu_units": {"type": "number", "default": 0.5},
|
||||
"memory_size": {"type": "string", "default": "512Mi"},
|
||||
"storage_size": {"type": "string", "default": "512Mi"},
|
||||
"denom": {"type": "string", "default": "uakt"},
|
||||
"price_amount": {"type": "integer", "default": 100},
|
||||
},
|
||||
"required": ["service_name", "container_image", "port"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class StdioJsonRpc:
|
||||
def __init__(self) -> None:
|
||||
self._in = sys.stdin.buffer
|
||||
self._out = sys.stdout.buffer
|
||||
self._mode: str | None = None # "headers" | "line"
|
||||
|
||||
def read_message(self) -> Optional[Dict[str, Any]]:
|
||||
while True:
|
||||
if self._mode == "line":
|
||||
line = self._in.readline()
|
||||
if not line:
|
||||
return None
|
||||
raw = line.decode("utf-8", "replace").strip()
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(msg, dict):
|
||||
return msg
|
||||
continue
|
||||
|
||||
first = self._in.readline()
|
||||
if not first:
|
||||
return None
|
||||
|
||||
if first in (b"\r\n", b"\n"):
|
||||
continue
|
||||
|
||||
# Auto-detect newline-delimited JSON framing.
|
||||
if self._mode is None and first.lstrip().startswith(b"{"):
|
||||
try:
|
||||
msg = json.loads(first.decode("utf-8", "replace"))
|
||||
except Exception:
|
||||
msg = None
|
||||
if isinstance(msg, dict):
|
||||
self._mode = "line"
|
||||
return msg
|
||||
|
||||
headers: Dict[str, str] = {}
|
||||
try:
|
||||
text = first.decode("utf-8", "replace").strip()
|
||||
except Exception:
|
||||
continue
|
||||
if ":" not in text:
|
||||
continue
|
||||
k, v = text.split(":", 1)
|
||||
headers[k.lower().strip()] = v.strip()
|
||||
|
||||
while True:
|
||||
line = self._in.readline()
|
||||
if not line:
|
||||
return None
|
||||
if line in (b"\r\n", b"\n"):
|
||||
break
|
||||
try:
|
||||
text = line.decode("utf-8", "replace").strip()
|
||||
except Exception:
|
||||
continue
|
||||
if ":" not in text:
|
||||
continue
|
||||
k, v = text.split(":", 1)
|
||||
headers[k.lower().strip()] = v.strip()
|
||||
|
||||
if "content-length" not in headers:
|
||||
return None
|
||||
try:
|
||||
length = int(headers["content-length"])
|
||||
except ValueError:
|
||||
return None
|
||||
body = self._in.read(length)
|
||||
if not body:
|
||||
return None
|
||||
self._mode = "headers"
|
||||
msg = json.loads(body.decode("utf-8", "replace"))
|
||||
if isinstance(msg, dict):
|
||||
return msg
|
||||
return None
|
||||
|
||||
def write_message(self, message: Dict[str, Any]) -> None:
|
||||
if self._mode == "line":
|
||||
payload = json.dumps(
|
||||
message, ensure_ascii=False, separators=(",", ":"), default=str
|
||||
).encode("utf-8")
|
||||
self._out.write(payload + b"\n")
|
||||
self._out.flush()
|
||||
return
|
||||
|
||||
body = json.dumps(message, ensure_ascii=False, separators=(",", ":")).encode(
|
||||
"utf-8"
|
||||
)
|
||||
header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
|
||||
self._out.write(header)
|
||||
self._out.write(body)
|
||||
self._out.flush()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
tools = AkashDocsTools()
|
||||
rpc = StdioJsonRpc()
|
||||
|
||||
handlers: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] = {
|
||||
"akash_docs_list_routes": lambda a: tools.akash_docs_list_routes(),
|
||||
"akash_docs_fetch": lambda a: tools.akash_docs_fetch(**a),
|
||||
"akash_docs_search": lambda a: tools.akash_docs_search(**a),
|
||||
"akash_sdl_snippet": lambda a: tools.akash_sdl_snippet(**a),
|
||||
}
|
||||
|
||||
while True:
|
||||
msg = rpc.read_message()
|
||||
if msg is None:
|
||||
return
|
||||
|
||||
method = msg.get("method")
|
||||
msg_id = msg.get("id")
|
||||
params = msg.get("params") or {}
|
||||
|
||||
try:
|
||||
if method == "initialize":
|
||||
result = {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"serverInfo": {"name": "akash_docs", "version": "0.1.0"},
|
||||
"capabilities": {"tools": {}},
|
||||
}
|
||||
rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result})
|
||||
continue
|
||||
|
||||
if method == "tools/list":
|
||||
rpc.write_message(
|
||||
{"jsonrpc": "2.0", "id": msg_id, "result": {"tools": TOOLS}}
|
||||
)
|
||||
continue
|
||||
|
||||
if method == "tools/call":
|
||||
tool_name = str(params.get("name") or "")
|
||||
args = params.get("arguments") or {}
|
||||
if tool_name not in handlers:
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
f"Unknown tool: {tool_name}\nKnown tools: {', '.join(sorted(handlers.keys()))}",
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
payload = handlers[tool_name](args)
|
||||
# Split payload: meta JSON + optional raw content.
|
||||
# If payload["data"]["content"] exists, emit it as a second text block for readability.
|
||||
data = payload.get("data") if isinstance(payload, dict) else None
|
||||
content_text = None
|
||||
if isinstance(data, dict) and isinstance(data.get("content"), str):
|
||||
content_text = data["content"]
|
||||
data = dict(data)
|
||||
data.pop("content", None)
|
||||
payload = dict(payload)
|
||||
payload["data"] = data
|
||||
|
||||
blocks = [json.dumps(payload, ensure_ascii=False, indent=2)]
|
||||
if content_text:
|
||||
blocks.append(content_text)
|
||||
result: Dict[str, Any] = {
|
||||
"content": [{"type": "text", "text": b} for b in blocks]
|
||||
}
|
||||
rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result})
|
||||
except Exception as e: # noqa: BLE001
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
f"Error: {e}",
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Ignore notifications.
|
||||
if msg_id is None:
|
||||
continue
|
||||
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
f"Unsupported method: {method}",
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
# Last-resort: avoid crashing the server.
|
||||
if msg_id is not None:
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(f"fatal error: {e}", is_error=True),
|
||||
}
|
||||
)
|
||||
11 mcp/cloudflare_safe/__init__.py Normal file
@@ -0,0 +1,11 @@
"""
cloudflare_safe MCP server.

Summary-first Cloudflare tooling with hard output caps and default redaction.
"""

from __future__ import annotations

__all__ = ["__version__"]

__version__ = "0.1.0"
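A minimal sketch of driving the snapshot tool directly, bypassing the stdio loop; it assumes CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID are exported and that the mcp package is importable:

from mcp.cloudflare_safe.server import CloudflareSafeTools

tools = CloudflareSafeTools()
# Summary-first: the full snapshot JSON is written to disk; only ids, counts and paths come back inline.
snap = tools.cf_snapshot(scopes=["tunnels", "access_apps"])
print(snap["summary"])
print(snap["data"]["snapshot_path"])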
6 mcp/cloudflare_safe/__main__.py Normal file
@@ -0,0 +1,6 @@
from __future__ import annotations

from .server import main

if __name__ == "__main__":
    main()
496 mcp/cloudflare_safe/cloudflare_api.py Normal file
@@ -0,0 +1,496 @@
from __future__ import annotations

import hashlib
import json
import os
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Sequence,
    Tuple,
)

CF_API_BASE = "https://api.cloudflare.com/client/v4"

def utc_now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def stable_hash(data: Any) -> str:
|
||||
blob = json.dumps(
|
||||
data, sort_keys=True, separators=(",", ":"), ensure_ascii=False
|
||||
).encode("utf-8")
|
||||
return hashlib.sha256(blob).hexdigest()
|
||||
|
||||
|
||||
class CloudflareError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CloudflareContext:
|
||||
api_token: str
|
||||
account_id: str
|
||||
|
||||
@staticmethod
|
||||
def from_env() -> "CloudflareContext":
|
||||
api_token = (
|
||||
os.getenv("CLOUDFLARE_API_TOKEN")
|
||||
or os.getenv("CF_API_TOKEN")
|
||||
or os.getenv("CLOUDFLARE_TOKEN")
|
||||
or ""
|
||||
).strip()
|
||||
account_id = (
|
||||
os.getenv("CLOUDFLARE_ACCOUNT_ID") or os.getenv("CF_ACCOUNT_ID") or ""
|
||||
).strip()
|
||||
|
||||
if not api_token:
|
||||
raise CloudflareError(
|
||||
"Missing Cloudflare API token. Set CLOUDFLARE_API_TOKEN (or CF_API_TOKEN)."
|
||||
)
|
||||
if not account_id:
|
||||
raise CloudflareError(
|
||||
"Missing Cloudflare account id. Set CLOUDFLARE_ACCOUNT_ID (or CF_ACCOUNT_ID)."
|
||||
)
|
||||
return CloudflareContext(api_token=api_token, account_id=account_id)
|
||||
|
||||
|
||||
class CloudflareClient:
|
||||
def __init__(self, *, api_token: str) -> None:
|
||||
self.api_token = api_token
|
||||
|
||||
def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Mapping[str, str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{CF_API_BASE}{path}"
|
||||
if params:
|
||||
url = f"{url}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
req = urllib.request.Request(
|
||||
url=url,
|
||||
method=method,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_token}",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
raw = resp.read()
|
||||
except urllib.error.HTTPError as e:
|
||||
raw = e.read() if hasattr(e, "read") else b""
|
||||
detail = raw.decode("utf-8", "replace")
|
||||
raise CloudflareError(
|
||||
f"Cloudflare API HTTP {e.code} for {path}: {detail}"
|
||||
) from e
|
||||
except urllib.error.URLError as e:
|
||||
raise CloudflareError(
|
||||
f"Cloudflare API request failed for {path}: {e}"
|
||||
) from e
|
||||
|
||||
try:
|
||||
data = json.loads(raw.decode("utf-8", "replace"))
|
||||
except json.JSONDecodeError:
|
||||
raise CloudflareError(
|
||||
f"Cloudflare API returned non-JSON for {path}: {raw[:200]!r}"
|
||||
)
|
||||
|
||||
if not data.get("success", True):
|
||||
raise CloudflareError(
|
||||
f"Cloudflare API error for {path}: {data.get('errors')}"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def paginate(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Mapping[str, str]] = None,
|
||||
per_page: int = 100,
|
||||
max_pages: int = 5,
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
"""
|
||||
Fetch a paginated Cloudflare endpoint.
|
||||
|
||||
Returns (results, result_info).
|
||||
"""
|
||||
results: List[Dict[str, Any]] = []
|
||||
page = 1
|
||||
last_info: Dict[str, Any] = {}
|
||||
|
||||
while True:
|
||||
merged_params: Dict[str, str] = {
|
||||
"page": str(page),
|
||||
"per_page": str(per_page),
|
||||
}
|
||||
if params:
|
||||
merged_params.update({k: str(v) for k, v in params.items()})
|
||||
|
||||
data = self._request("GET", path, params=merged_params)
|
||||
batch = data.get("result") or []
|
||||
if not isinstance(batch, list):
|
||||
batch = [batch]
|
||||
results.extend(batch)
|
||||
last_info = data.get("result_info") or {}
|
||||
|
||||
total_pages = int(last_info.get("total_pages") or 1)
|
||||
if page >= total_pages or page >= max_pages:
|
||||
break
|
||||
page += 1
|
||||
|
||||
return results, last_info
|
||||
|
||||
def list_zones(self) -> List[Dict[str, Any]]:
|
||||
zones, _info = self.paginate("/zones", max_pages=2)
|
||||
return zones
|
||||
|
||||
def list_dns_records_summary(
|
||||
self, zone_id: str, *, max_pages: int = 1
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
return self.paginate(f"/zones/{zone_id}/dns_records", max_pages=max_pages)
|
||||
|
||||
def list_tunnels(self, account_id: str) -> List[Dict[str, Any]]:
|
||||
tunnels, _info = self.paginate(
|
||||
f"/accounts/{account_id}/cfd_tunnel", max_pages=2
|
||||
)
|
||||
return tunnels
|
||||
|
||||
def list_tunnel_connections(
|
||||
self, account_id: str, tunnel_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
data = self._request(
|
||||
"GET", f"/accounts/{account_id}/cfd_tunnel/{tunnel_id}/connections"
|
||||
)
|
||||
result = data.get("result") or []
|
||||
return result if isinstance(result, list) else [result]
|
||||
|
||||
def list_access_apps(self, account_id: str) -> List[Dict[str, Any]]:
|
||||
apps, _info = self.paginate(f"/accounts/{account_id}/access/apps", max_pages=3)
|
||||
return apps
|
||||
|
||||
def list_access_policies(
|
||||
self, account_id: str, app_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
policies, _info = self.paginate(
|
||||
f"/accounts/{account_id}/access/apps/{app_id}/policies",
|
||||
max_pages=3,
|
||||
)
|
||||
return policies
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SnapshotMeta:
|
||||
snapshot_id: str
|
||||
created_at: str
|
||||
scopes: List[str]
|
||||
snapshot_path: str
|
||||
|
||||
|
||||
class SnapshotStore:
|
||||
def __init__(self, root_dir: Path) -> None:
|
||||
self.root_dir = root_dir
|
||||
self.snapshots_dir = root_dir / "snapshots"
|
||||
self.diffs_dir = root_dir / "diffs"
|
||||
self.snapshots_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.diffs_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._index: Dict[str, SnapshotMeta] = {}
|
||||
|
||||
def get(self, snapshot_id: str) -> SnapshotMeta:
|
||||
if snapshot_id not in self._index:
|
||||
raise CloudflareError(f"Unknown snapshot_id: {snapshot_id}")
|
||||
return self._index[snapshot_id]
|
||||
|
||||
def load_snapshot(self, snapshot_id: str) -> Dict[str, Any]:
|
||||
meta = self.get(snapshot_id)
|
||||
return json.loads(Path(meta.snapshot_path).read_text(encoding="utf-8"))
|
||||
|
||||
def create_snapshot(
|
||||
self,
|
||||
*,
|
||||
client: CloudflareClient,
|
||||
ctx: CloudflareContext,
|
||||
scopes: Sequence[str],
|
||||
zone_id: Optional[str] = None,
|
||||
zone_name: Optional[str] = None,
|
||||
dns_max_pages: int = 1,
|
||||
) -> Tuple[SnapshotMeta, Dict[str, Any]]:
|
||||
scopes_norm = sorted(set(scopes))
|
||||
created_at = utc_now_iso()
|
||||
|
||||
zones = client.list_zones()
|
||||
zones_min = [
|
||||
{
|
||||
"id": z.get("id"),
|
||||
"name": z.get("name"),
|
||||
"status": z.get("status"),
|
||||
"paused": z.get("paused"),
|
||||
}
|
||||
for z in zones
|
||||
]
|
||||
|
||||
selected_zone_id = zone_id
|
||||
if not selected_zone_id and zone_name:
|
||||
for z in zones_min:
|
||||
if z.get("name") == zone_name:
|
||||
selected_zone_id = str(z.get("id"))
|
||||
break
|
||||
|
||||
snapshot: Dict[str, Any] = {
|
||||
"meta": {
|
||||
"snapshot_id": "",
|
||||
"created_at": created_at,
|
||||
"account_id": ctx.account_id,
|
||||
"scopes": scopes_norm,
|
||||
},
|
||||
"zones": zones_min,
|
||||
}
|
||||
|
||||
if "tunnels" in scopes_norm:
|
||||
tunnels = client.list_tunnels(ctx.account_id)
|
||||
tunnels_min: List[Dict[str, Any]] = []
|
||||
for t in tunnels:
|
||||
tid = t.get("id")
|
||||
name = t.get("name")
|
||||
status = t.get("status")
|
||||
connector_count: Optional[int] = None
|
||||
last_seen: Optional[str] = None
|
||||
if tid and status != "deleted":
|
||||
conns = client.list_tunnel_connections(ctx.account_id, str(tid))
|
||||
connector_count = len(conns)
|
||||
# Pick the most recent 'opened_at' if present.
|
||||
opened = [c.get("opened_at") for c in conns if isinstance(c, dict)]
|
||||
opened = [o for o in opened if isinstance(o, str)]
|
||||
last_seen = max(opened) if opened else None
|
||||
|
||||
tunnels_min.append(
|
||||
{
|
||||
"id": tid,
|
||||
"name": name,
|
||||
"status": status,
|
||||
"created_at": t.get("created_at"),
|
||||
"deleted_at": t.get("deleted_at"),
|
||||
"connector_count": connector_count,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
snapshot["tunnels"] = tunnels_min
|
||||
|
||||
if "access_apps" in scopes_norm:
|
||||
apps = client.list_access_apps(ctx.account_id)
|
||||
apps_min = [
|
||||
{
|
||||
"id": a.get("id"),
|
||||
"name": a.get("name"),
|
||||
"domain": a.get("domain"),
|
||||
"type": a.get("type"),
|
||||
"created_at": a.get("created_at"),
|
||||
"updated_at": a.get("updated_at"),
|
||||
}
|
||||
for a in apps
|
||||
]
|
||||
snapshot["access_apps"] = apps_min
|
||||
|
||||
if "dns" in scopes_norm:
|
||||
if selected_zone_id:
|
||||
records, info = client.list_dns_records_summary(
|
||||
selected_zone_id, max_pages=dns_max_pages
|
||||
)
|
||||
records_min = [
|
||||
{
|
||||
"id": r.get("id"),
|
||||
"type": r.get("type"),
|
||||
"name": r.get("name"),
|
||||
"content": r.get("content"),
|
||||
"proxied": r.get("proxied"),
|
||||
"ttl": r.get("ttl"),
|
||||
}
|
||||
for r in records
|
||||
]
|
||||
snapshot["dns"] = {
|
||||
"zone_id": selected_zone_id,
|
||||
"zone_name": zone_name,
|
||||
"result_info": info,
|
||||
"records_sample": records_min,
|
||||
}
|
||||
else:
|
||||
snapshot["dns"] = {
|
||||
"note": "dns scope requested but no zone_id/zone_name provided; only zones list included",
|
||||
}
|
||||
|
||||
snapshot_id = f"cf_{created_at.replace(':', '').replace('-', '').replace('.', '')}_{stable_hash(snapshot)[:10]}"
|
||||
snapshot["meta"]["snapshot_id"] = snapshot_id
|
||||
|
||||
path = self.snapshots_dir / f"{snapshot_id}.json"
|
||||
path.write_text(
|
||||
json.dumps(snapshot, indent=2, ensure_ascii=False), encoding="utf-8"
|
||||
)
|
||||
|
||||
meta = SnapshotMeta(
|
||||
snapshot_id=snapshot_id,
|
||||
created_at=created_at,
|
||||
scopes=scopes_norm,
|
||||
snapshot_path=str(path),
|
||||
)
|
||||
self._index[snapshot_id] = meta
|
||||
return meta, snapshot
|
||||
|
||||
def diff(
|
||||
self,
|
||||
*,
|
||||
from_snapshot_id: str,
|
||||
to_snapshot_id: str,
|
||||
scopes: Optional[Sequence[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
before = self.load_snapshot(from_snapshot_id)
|
||||
after = self.load_snapshot(to_snapshot_id)
|
||||
|
||||
scopes_before = set(before.get("meta", {}).get("scopes") or [])
|
||||
scopes_after = set(after.get("meta", {}).get("scopes") or [])
|
||||
scopes_all = sorted(scopes_before | scopes_after)
|
||||
scopes_use = sorted(set(scopes or scopes_all))
|
||||
|
||||
def index_by_id(
|
||||
items: Iterable[Mapping[str, Any]],
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
out: Dict[str, Dict[str, Any]] = {}
|
||||
for it in items:
|
||||
_id = it.get("id")
|
||||
if _id is None:
|
||||
continue
|
||||
out[str(_id)] = dict(it)
|
||||
return out
|
||||
|
||||
diff_out: Dict[str, Any] = {
|
||||
"from": from_snapshot_id,
|
||||
"to": to_snapshot_id,
|
||||
"scopes": scopes_use,
|
||||
"changes": {},
|
||||
}
|
||||
|
||||
for scope in scopes_use:
|
||||
if scope not in {"tunnels", "access_apps", "zones"}:
|
||||
continue
|
||||
b_items = before.get(scope) or []
|
||||
a_items = after.get(scope) or []
|
||||
if not isinstance(b_items, list) or not isinstance(a_items, list):
|
||||
continue
|
||||
b_map = index_by_id(b_items)
|
||||
a_map = index_by_id(a_items)
|
||||
added = [a_map[k] for k in sorted(set(a_map) - set(b_map))]
|
||||
removed = [b_map[k] for k in sorted(set(b_map) - set(a_map))]
|
||||
|
||||
changed: List[Dict[str, Any]] = []
|
||||
for k in sorted(set(a_map) & set(b_map)):
|
||||
if stable_hash(a_map[k]) != stable_hash(b_map[k]):
|
||||
changed.append({"id": k, "before": b_map[k], "after": a_map[k]})
|
||||
|
||||
diff_out["changes"][scope] = {
|
||||
"added": [{"id": x.get("id"), "name": x.get("name")} for x in added],
|
||||
"removed": [
|
||||
{"id": x.get("id"), "name": x.get("name")} for x in removed
|
||||
],
|
||||
"changed": [
|
||||
{"id": x.get("id"), "name": x.get("after", {}).get("name")}
|
||||
for x in changed
|
||||
],
|
||||
"counts": {
|
||||
"added": len(added),
|
||||
"removed": len(removed),
|
||||
"changed": len(changed),
|
||||
},
|
||||
}
|
||||
|
||||
diff_path = self.diffs_dir / f"{from_snapshot_id}_to_{to_snapshot_id}.json"
|
||||
diff_path.write_text(
|
||||
json.dumps(diff_out, indent=2, ensure_ascii=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
diff_out["diff_path"] = str(diff_path)
|
||||
return diff_out
|
||||
|
||||
|
||||
def parse_cloudflared_config_ingress(config_text: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Best-effort parser for cloudflared YAML config ingress rules.
|
||||
|
||||
We intentionally avoid a YAML dependency; this extracts common patterns:
|
||||
- hostname: example.com
|
||||
service: http://127.0.0.1:8080
|
||||
"""
|
||||
rules: List[Dict[str, str]] = []
|
||||
lines = config_text.splitlines()
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
stripped = line.lstrip()
|
||||
if not stripped.startswith("-"):
|
||||
i += 1
|
||||
continue
|
||||
after_dash = stripped[1:].lstrip()
|
||||
if not after_dash.startswith("hostname:"):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
hostname = after_dash[len("hostname:") :].strip().strip('"').strip("'")
|
||||
base_indent = len(line) - len(line.lstrip())
|
||||
|
||||
service = ""
|
||||
j = i + 1
|
||||
while j < len(lines):
|
||||
next_line = lines[j]
|
||||
if next_line.strip() == "":
|
||||
j += 1
|
||||
continue
|
||||
|
||||
next_indent = len(next_line) - len(next_line.lstrip())
|
||||
if next_indent <= base_indent:
|
||||
break
|
||||
|
||||
next_stripped = next_line.lstrip()
|
||||
if next_stripped.startswith("service:"):
|
||||
service = next_stripped[len("service:") :].strip().strip('"').strip("'")
|
||||
break
|
||||
j += 1
|
||||
|
||||
rules.append({"hostname": hostname, "service": service})
|
||||
i = j
|
||||
return rules
|
||||
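# Illustrative input/output for parse_cloudflared_config_ingress (assumed config fragment):
#
#   ingress:
#     - hostname: app.example.com
#       service: http://127.0.0.1:8080
#     - service: http_status:404
#
# returns [{"hostname": "app.example.com", "service": "http://127.0.0.1:8080"}];
# the catch-all rule without a hostname is skipped by the hostname check above.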
|
||||
|
||||
def ingress_summary_from_file(
|
||||
*,
|
||||
config_path: str,
|
||||
max_rules: int = 50,
|
||||
) -> Dict[str, Any]:
|
||||
path = Path(config_path)
|
||||
if not path.exists():
|
||||
raise CloudflareError(f"cloudflared config not found: {config_path}")
|
||||
text = path.read_text(encoding="utf-8", errors="replace")
|
||||
rules = parse_cloudflared_config_ingress(text)
|
||||
hostnames = sorted({r["hostname"] for r in rules if r.get("hostname")})
|
||||
return {
|
||||
"config_path": config_path,
|
||||
"ingress_rule_count": len(rules),
|
||||
"hostnames": hostnames[:max_rules],
|
||||
"rules_sample": rules[:max_rules],
|
||||
"truncated": len(rules) > max_rules,
|
||||
}
|
||||
725 mcp/cloudflare_safe/server.py Normal file
@@ -0,0 +1,725 @@
from __future__ import annotations

import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

from .cloudflare_api import (
    CloudflareClient,
    CloudflareContext,
    CloudflareError,
    SnapshotStore,
    ingress_summary_from_file,
)

MAX_BYTES_DEFAULT = 32_000

def _repo_root() -> Path:
    # server.py -> cloudflare_safe -> mcp -> cloudflare -> <repo root> (parents[3], matching akash_docs/server.py)
    return Path(__file__).resolve().parents[3]
|
||||
|
||||
def _max_bytes() -> int:
|
||||
raw = (os.getenv("VM_MCP_MAX_BYTES") or "").strip()
|
||||
if not raw:
|
||||
return MAX_BYTES_DEFAULT
|
||||
try:
|
||||
return max(4_096, int(raw))
|
||||
except ValueError:
|
||||
return MAX_BYTES_DEFAULT
|
||||
|
||||
|
||||
def _redact(obj: Any) -> Any:
|
||||
sensitive_keys = ("token", "secret", "password", "private", "key", "certificate")
|
||||
|
||||
if isinstance(obj, dict):
|
||||
out: Dict[str, Any] = {}
|
||||
for k, v in obj.items():
|
||||
if any(s in str(k).lower() for s in sensitive_keys):
|
||||
out[k] = "<REDACTED>"
|
||||
else:
|
||||
out[k] = _redact(v)
|
||||
return out
|
||||
if isinstance(obj, list):
|
||||
return [_redact(v) for v in obj]
|
||||
if isinstance(obj, str):
|
||||
if obj.startswith("ghp_") or obj.startswith("github_pat_"):
|
||||
return "<REDACTED>"
|
||||
return obj
|
||||
return obj
|
||||
|
||||
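# Illustrative behaviour of _redact (assumed values):
#   _redact({"api_token": "abc123", "zone": {"name": "example.com", "tls_key": "k"}})
#     -> {"api_token": "<REDACTED>", "zone": {"name": "example.com", "tls_key": "<REDACTED>"}}
# Keys containing token/secret/password/private/key/certificate are masked, and bare
# strings with GitHub token prefixes (ghp_, github_pat_) are masked as well.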
|
||||
def _safe_json(payload: Dict[str, Any]) -> str:
|
||||
payload = _redact(payload)
|
||||
raw = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||||
if len(raw.encode("utf-8")) <= _max_bytes():
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
|
||||
# Truncate: keep only summary + next_steps.
|
||||
truncated = {
|
||||
"ok": payload.get("ok", True),
|
||||
"truncated": True,
|
||||
"summary": payload.get("summary", "Response exceeded max size; truncated."),
|
||||
"next_steps": payload.get(
|
||||
"next_steps",
|
||||
[
|
||||
"request a narrower scope (e.g., scopes=['tunnels'])",
|
||||
"request an export path instead of inline content",
|
||||
],
|
||||
),
|
||||
}
|
||||
return json.dumps(truncated, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def _mcp_text_result(
|
||||
payload: Dict[str, Any], *, is_error: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
result: Dict[str, Any] = {
|
||||
"content": [{"type": "text", "text": _safe_json(payload)}]
|
||||
}
|
||||
if is_error:
|
||||
result["isError"] = True
|
||||
return result
|
||||
|
||||
|
||||
def _default_state_dir() -> Path:
|
||||
return _repo_root() / "archive_runtime" / "cloudflare_mcp"
|
||||
|
||||
|
||||
class CloudflareSafeTools:
|
||||
def __init__(self) -> None:
|
||||
self.store = SnapshotStore(
|
||||
Path(os.getenv("VM_CF_MCP_STATE_DIR") or _default_state_dir())
|
||||
)
|
||||
|
||||
def cf_snapshot(
|
||||
self,
|
||||
*,
|
||||
scopes: Optional[Sequence[str]] = None,
|
||||
zone_id: Optional[str] = None,
|
||||
zone_name: Optional[str] = None,
|
||||
dns_max_pages: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
scopes_use = list(scopes or ["tunnels", "access_apps"])
|
||||
ctx = CloudflareContext.from_env()
|
||||
client = CloudflareClient(api_token=ctx.api_token)
|
||||
meta, snapshot = self.store.create_snapshot(
|
||||
client=client,
|
||||
ctx=ctx,
|
||||
scopes=scopes_use,
|
||||
zone_id=zone_id,
|
||||
zone_name=zone_name,
|
||||
dns_max_pages=dns_max_pages,
|
||||
)
|
||||
|
||||
summary = (
|
||||
f"Snapshot {meta.snapshot_id} captured "
|
||||
f"(scopes={','.join(meta.scopes)}) and written to {meta.snapshot_path}."
|
||||
)
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": summary,
|
||||
"data": {
|
||||
"snapshot_id": meta.snapshot_id,
|
||||
"created_at": meta.created_at,
|
||||
"scopes": meta.scopes,
|
||||
"snapshot_path": meta.snapshot_path,
|
||||
"counts": {
|
||||
"zones": len(snapshot.get("zones") or []),
|
||||
"tunnels": len(snapshot.get("tunnels") or []),
|
||||
"access_apps": len(snapshot.get("access_apps") or []),
|
||||
},
|
||||
},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"cf_config_diff(from_snapshot_id=..., to_snapshot_id=...)",
|
||||
"cf_export_config(full=false, snapshot_id=...)",
|
||||
],
|
||||
}
|
||||
|
||||
def cf_refresh(
|
||||
self,
|
||||
*,
|
||||
snapshot_id: str,
|
||||
scopes: Optional[Sequence[str]] = None,
|
||||
dns_max_pages: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
before_meta = self.store.get(snapshot_id)
|
||||
before = self.store.load_snapshot(snapshot_id)
|
||||
scopes_use = list(scopes or (before.get("meta", {}).get("scopes") or []))
|
||||
|
||||
ctx = CloudflareContext.from_env()
|
||||
client = CloudflareClient(api_token=ctx.api_token)
|
||||
|
||||
meta, _snapshot = self.store.create_snapshot(
|
||||
client=client,
|
||||
ctx=ctx,
|
||||
scopes=scopes_use,
|
||||
zone_id=(before.get("dns") or {}).get("zone_id"),
|
||||
zone_name=(before.get("dns") or {}).get("zone_name"),
|
||||
dns_max_pages=dns_max_pages,
|
||||
)
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"Refreshed {before_meta.snapshot_id} -> {meta.snapshot_id} (scopes={','.join(meta.scopes)}).",
|
||||
"data": {
|
||||
"from_snapshot_id": before_meta.snapshot_id,
|
||||
"to_snapshot_id": meta.snapshot_id,
|
||||
"snapshot_path": meta.snapshot_path,
|
||||
},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"cf_config_diff(from_snapshot_id=..., to_snapshot_id=...)",
|
||||
],
|
||||
}
|
||||
|
||||
def cf_config_diff(
|
||||
self,
|
||||
*,
|
||||
from_snapshot_id: str,
|
||||
to_snapshot_id: str,
|
||||
scopes: Optional[Sequence[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
diff = self.store.diff(
|
||||
from_snapshot_id=from_snapshot_id,
|
||||
to_snapshot_id=to_snapshot_id,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
# Keep the response small; point to diff_path for full detail.
|
||||
changes = diff.get("changes") or {}
|
||||
counts = {
|
||||
scope: (changes.get(scope) or {}).get("counts")
|
||||
for scope in sorted(changes.keys())
|
||||
}
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"Diff computed and written to {diff.get('diff_path')}.",
|
||||
"data": {
|
||||
"from_snapshot_id": from_snapshot_id,
|
||||
"to_snapshot_id": to_snapshot_id,
|
||||
"scopes": diff.get("scopes"),
|
||||
"counts": counts,
|
||||
"diff_path": diff.get("diff_path"),
|
||||
},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"Use filesystem MCP to open diff_path for full details",
|
||||
"Run cf_export_config(full=false, snapshot_id=...) for a safe export path",
|
||||
],
|
||||
}
|
||||
|
||||
def cf_export_config(
|
||||
self,
|
||||
*,
|
||||
snapshot_id: Optional[str] = None,
|
||||
full: bool = False,
|
||||
scopes: Optional[Sequence[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if snapshot_id is None:
|
||||
snap = self.cf_snapshot(scopes=scopes)
|
||||
snapshot_id = str((snap.get("data") or {}).get("snapshot_id"))
|
||||
|
||||
meta = self.store.get(snapshot_id)
|
||||
if not full:
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": "Export is summary-first; full config requires full=true.",
|
||||
"data": {
|
||||
"snapshot_id": meta.snapshot_id,
|
||||
"snapshot_path": meta.snapshot_path,
|
||||
},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"Use filesystem MCP to open snapshot_path",
|
||||
"If you truly need inline data, call cf_export_config(full=true, snapshot_id=...)",
|
||||
],
|
||||
}
|
||||
|
||||
snapshot = self.store.load_snapshot(snapshot_id)
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": "Full snapshot export (redacted + size-capped). Prefer snapshot_path for large data.",
|
||||
"data": snapshot,
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
f"Snapshot file: {meta.snapshot_path}",
|
||||
],
|
||||
}
|
||||
|
||||
def cf_tunnel_status(
|
||||
self,
|
||||
*,
|
||||
snapshot_id: Optional[str] = None,
|
||||
tunnel_name: Optional[str] = None,
|
||||
tunnel_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if snapshot_id:
|
||||
snap = self.store.load_snapshot(snapshot_id)
|
||||
tunnels = snap.get("tunnels") or []
|
||||
else:
|
||||
snap = self.cf_snapshot(scopes=["tunnels"])
|
||||
sid = str((snap.get("data") or {}).get("snapshot_id"))
|
||||
tunnels = self.store.load_snapshot(sid).get("tunnels") or []
|
||||
|
||||
def matches(t: Dict[str, Any]) -> bool:
|
||||
if tunnel_id and str(t.get("id")) != str(tunnel_id):
|
||||
return False
|
||||
if tunnel_name and str(t.get("name")) != str(tunnel_name):
|
||||
return False
|
||||
return True
|
||||
|
||||
filtered = [t for t in tunnels if isinstance(t, dict) and matches(t)]
|
||||
if not filtered and (tunnel_id or tunnel_name):
|
||||
return {
|
||||
"ok": False,
|
||||
"summary": "Tunnel not found in snapshot.",
|
||||
"data": {"tunnel_id": tunnel_id, "tunnel_name": tunnel_name},
|
||||
"truncated": False,
|
||||
"next_steps": ["Call cf_snapshot(scopes=['tunnels']) and retry."],
|
||||
}
|
||||
|
||||
connectors = [t.get("connector_count") for t in filtered if isinstance(t, dict)]
|
||||
connectors = [c for c in connectors if isinstance(c, int)]
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"Returned {len(filtered)} tunnel(s).",
|
||||
"data": {
|
||||
"tunnels": [
|
||||
{
|
||||
"id": t.get("id"),
|
||||
"name": t.get("name"),
|
||||
"status": t.get("status"),
|
||||
"connector_count": t.get("connector_count"),
|
||||
"last_seen": t.get("last_seen"),
|
||||
}
|
||||
for t in filtered
|
||||
],
|
||||
"connectors_total": sum(connectors) if connectors else 0,
|
||||
},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"For local ingress hostnames, use cf_tunnel_ingress_summary(config_path='/etc/cloudflared/config.yml')",
|
||||
],
|
||||
}
|
||||
|
||||
def cf_tunnel_ingress_summary(
|
||||
self,
|
||||
*,
|
||||
config_path: str = "/etc/cloudflared/config.yml",
|
||||
full: bool = False,
|
||||
max_rules: int = 50,
|
||||
) -> Dict[str, Any]:
|
||||
summary = ingress_summary_from_file(
|
||||
config_path=config_path, max_rules=max_rules
|
||||
)
|
||||
if not full:
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"Parsed ingress hostnames from {config_path}.",
|
||||
"data": {
|
||||
"config_path": summary["config_path"],
|
||||
"ingress_rule_count": summary["ingress_rule_count"],
|
||||
"hostnames": summary["hostnames"],
|
||||
"truncated": summary["truncated"],
|
||||
},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"Call cf_tunnel_ingress_summary(full=true, ...) to include service mappings (still capped).",
|
||||
],
|
||||
}
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"Ingress summary (full=true) for {config_path}.",
|
||||
"data": summary,
|
||||
"truncated": False,
|
||||
"next_steps": [],
|
||||
}
|
||||
|
||||
def cf_access_policy_list(
|
||||
self,
|
||||
*,
|
||||
app_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
ctx = CloudflareContext.from_env()
|
||||
client = CloudflareClient(api_token=ctx.api_token)
|
||||
|
||||
if not app_id:
|
||||
apps = client.list_access_apps(ctx.account_id)
|
||||
apps_min = [
|
||||
{
|
||||
"id": a.get("id"),
|
||||
"name": a.get("name"),
|
||||
"domain": a.get("domain"),
|
||||
"type": a.get("type"),
|
||||
}
|
||||
for a in apps
|
||||
]
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"Returned {len(apps_min)} Access app(s). Provide app_id to list policies.",
|
||||
"data": {"apps": apps_min},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"Call cf_access_policy_list(app_id=...)",
|
||||
],
|
||||
}
|
||||
|
||||
policies = client.list_access_policies(ctx.account_id, app_id)
|
||||
policies_min = [
|
||||
{
|
||||
"id": p.get("id"),
|
||||
"name": p.get("name"),
|
||||
"decision": p.get("decision"),
|
||||
"precedence": p.get("precedence"),
|
||||
}
|
||||
for p in policies
|
||||
]
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"Returned {len(policies_min)} policy/policies for app_id={app_id}.",
|
||||
"data": {"app_id": app_id, "policies": policies_min},
|
||||
"truncated": False,
|
||||
"next_steps": [],
|
||||
}
|
||||
|
||||
|
||||
TOOLS: List[Dict[str, Any]] = [
|
||||
{
|
||||
"name": "cf_snapshot",
|
||||
"description": "Create a summary-first Cloudflare state snapshot (writes JSON to disk; returns snapshot_id + paths).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"scopes": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Scopes to fetch (default: ['tunnels','access_apps']). Supported: zones,tunnels,access_apps,dns",
|
||||
},
|
||||
"zone_id": {"type": "string"},
|
||||
"zone_name": {"type": "string"},
|
||||
"dns_max_pages": {"type": "integer", "default": 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "cf_refresh",
|
||||
"description": "Refresh a prior snapshot (creates a new snapshot_id).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"snapshot_id": {"type": "string"},
|
||||
"scopes": {"type": "array", "items": {"type": "string"}},
|
||||
"dns_max_pages": {"type": "integer", "default": 1},
|
||||
},
|
||||
"required": ["snapshot_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "cf_config_diff",
|
||||
"description": "Diff two snapshots (summary counts inline; full diff written to disk).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"from_snapshot_id": {"type": "string"},
|
||||
"to_snapshot_id": {"type": "string"},
|
||||
"scopes": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"required": ["from_snapshot_id", "to_snapshot_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "cf_export_config",
|
||||
"description": "Export snapshot config. Defaults to summary-only; full=true returns redacted + size-capped data.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"snapshot_id": {"type": "string"},
|
||||
"full": {"type": "boolean", "default": False},
|
||||
"scopes": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "cf_tunnel_status",
|
||||
"description": "Return tunnel status summary (connector count, last seen).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"snapshot_id": {"type": "string"},
|
||||
"tunnel_name": {"type": "string"},
|
||||
"tunnel_id": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "cf_tunnel_ingress_summary",
|
||||
"description": "Parse cloudflared ingress hostnames from a local config file (never dumps full YAML unless full=true, still capped).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config_path": {
|
||||
"type": "string",
|
||||
"default": "/etc/cloudflared/config.yml",
|
||||
},
|
||||
"full": {"type": "boolean", "default": False},
|
||||
"max_rules": {"type": "integer", "default": 50},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "cf_access_policy_list",
|
||||
"description": "List Access apps, or policies for a specific app_id (summary-only).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"app_id": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
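For reference, a minimal sketch of a tools/call request targeting one of the schemas above (argument values are illustrative; the shape follows the MCP 2024-11-05 JSON-RPC convention this server speaks):

import json

# Hypothetical client-side request for cf_tunnel_ingress_summary; values are examples only.
request = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "tools/call",
    "params": {
        "name": "cf_tunnel_ingress_summary",
        "arguments": {"config_path": "/etc/cloudflared/config.yml", "max_rules": 10},
    },
}
print(json.dumps(request))  # one JSON object per line works with the newline-delimited framing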
|
||||
|
||||
|
||||
class StdioJsonRpc:
|
||||
def __init__(self) -> None:
|
||||
self._in = sys.stdin.buffer
|
||||
self._out = sys.stdout.buffer
|
||||
self._mode: str | None = None # "headers" | "line"
|
||||
|
||||
def read_message(self) -> Optional[Dict[str, Any]]:
|
||||
while True:
|
||||
if self._mode == "line":
|
||||
line = self._in.readline()
|
||||
if not line:
|
||||
return None
|
||||
raw = line.decode("utf-8", "replace").strip()
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(msg, dict):
|
||||
return msg
|
||||
continue
|
||||
|
||||
first = self._in.readline()
|
||||
if not first:
|
||||
return None
|
||||
|
||||
if first in (b"\r\n", b"\n"):
|
||||
continue
|
||||
|
||||
# Auto-detect newline-delimited JSON framing.
|
||||
if self._mode is None and first.lstrip().startswith(b"{"):
|
||||
try:
|
||||
msg = json.loads(first.decode("utf-8", "replace"))
|
||||
except Exception:
|
||||
msg = None
|
||||
if isinstance(msg, dict):
|
||||
self._mode = "line"
|
||||
return msg
|
||||
|
||||
headers: Dict[str, str] = {}
|
||||
try:
|
||||
text = first.decode("utf-8", "replace").strip()
|
||||
except Exception:
|
||||
continue
|
||||
if ":" not in text:
|
||||
continue
|
||||
k, v = text.split(":", 1)
|
||||
headers[k.lower().strip()] = v.strip()
|
||||
|
||||
while True:
|
||||
line = self._in.readline()
|
||||
if not line:
|
||||
return None
|
||||
if line in (b"\r\n", b"\n"):
|
||||
break
|
||||
try:
|
||||
text = line.decode("utf-8", "replace").strip()
|
||||
except Exception:
|
||||
continue
|
||||
if ":" not in text:
|
||||
continue
|
||||
k, v = text.split(":", 1)
|
||||
headers[k.lower().strip()] = v.strip()
|
||||
|
||||
if "content-length" not in headers:
|
||||
return None
|
||||
try:
|
||||
length = int(headers["content-length"])
|
||||
except ValueError:
|
||||
return None
|
||||
body = self._in.read(length)
|
||||
if not body:
|
||||
return None
|
||||
self._mode = "headers"
|
||||
msg = json.loads(body.decode("utf-8", "replace"))
|
||||
if isinstance(msg, dict):
|
||||
return msg
|
||||
return None
|
||||
|
||||
def write_message(self, message: Dict[str, Any]) -> None:
|
||||
if self._mode == "line":
|
||||
payload = json.dumps(
|
||||
message, ensure_ascii=False, separators=(",", ":"), default=str
|
||||
).encode("utf-8")
|
||||
self._out.write(payload + b"\n")
|
||||
self._out.flush()
|
||||
return
|
||||
|
||||
body = json.dumps(
message, ensure_ascii=False, separators=(",", ":"), default=str
).encode("utf-8")
|
||||
header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
|
||||
self._out.write(header)
|
||||
self._out.write(body)
|
||||
self._out.flush()
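StdioJsonRpc above accepts either LSP-style Content-Length framing or newline-delimited JSON, detected from the first line it reads. A minimal sketch of producing both framings for the same message (illustrative only):

import json

message = {"jsonrpc": "2.0", "id": 1, "method": "tools/list"}
body = json.dumps(message, separators=(",", ":")).encode("utf-8")

# Header framing: what read_message() expects when the first line is a "Key: value" header.
header_framed = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8") + body

# Line framing: detected when the first non-empty line parses as a JSON object.
line_framed = body + b"\n"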
|
||||
|
||||
|
||||
def main() -> None:
|
||||
tools = CloudflareSafeTools()
|
||||
rpc = StdioJsonRpc()
|
||||
|
||||
handlers: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] = {
|
||||
"cf_snapshot": lambda a: tools.cf_snapshot(**a),
|
||||
"cf_refresh": lambda a: tools.cf_refresh(**a),
|
||||
"cf_config_diff": lambda a: tools.cf_config_diff(**a),
|
||||
"cf_export_config": lambda a: tools.cf_export_config(**a),
|
||||
"cf_tunnel_status": lambda a: tools.cf_tunnel_status(**a),
|
||||
"cf_tunnel_ingress_summary": lambda a: tools.cf_tunnel_ingress_summary(**a),
|
||||
"cf_access_policy_list": lambda a: tools.cf_access_policy_list(**a),
|
||||
}
|
||||
|
||||
while True:
|
||||
msg = rpc.read_message()
|
||||
if msg is None:
|
||||
return
|
||||
|
||||
method = msg.get("method")
|
||||
msg_id = msg.get("id")
|
||||
params = msg.get("params") or {}
|
||||
|
||||
try:
|
||||
if method == "initialize":
|
||||
result = {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"serverInfo": {"name": "cloudflare_safe", "version": "0.1.0"},
|
||||
"capabilities": {"tools": {}},
|
||||
}
|
||||
rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result})
|
||||
continue
|
||||
|
||||
if method == "tools/list":
|
||||
rpc.write_message(
|
||||
{"jsonrpc": "2.0", "id": msg_id, "result": {"tools": TOOLS}}
|
||||
)
|
||||
continue
|
||||
|
||||
if method == "tools/call":
|
||||
tool_name = str(params.get("name") or "")
|
||||
args = params.get("arguments") or {}
|
||||
if tool_name not in handlers:
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
{
|
||||
"ok": False,
|
||||
"summary": f"Unknown tool: {tool_name}",
|
||||
"data": {"known_tools": sorted(handlers.keys())},
|
||||
"truncated": False,
|
||||
"next_steps": ["Call tools/list"],
|
||||
},
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
payload = handlers[tool_name](args)
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(payload),
|
||||
}
|
||||
)
|
||||
except CloudflareError as e:
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
{
|
||||
"ok": False,
|
||||
"summary": str(e),
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"Verify CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID are set",
|
||||
"Retry with a narrower scope",
|
||||
],
|
||||
},
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
{
|
||||
"ok": False,
|
||||
"summary": f"Unhandled error: {e}",
|
||||
"truncated": False,
|
||||
"next_steps": ["Retry with a narrower scope"],
|
||||
},
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Ignore notifications.
|
||||
if msg_id is None:
|
||||
continue
|
||||
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
{
|
||||
"ok": False,
|
||||
"summary": f"Unsupported method: {method}",
|
||||
"truncated": False,
|
||||
},
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
# Last-resort: avoid crashing the server.
|
||||
if msg_id is not None:
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
{
|
||||
"ok": False,
|
||||
"summary": f"fatal error: {e}",
|
||||
"truncated": False,
|
||||
},
is_error=True,
),
|
||||
}
|
||||
)
|
||||
6
mcp/oracle_answer/__main__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .server import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
386
mcp/oracle_answer/server.py
Normal file
@@ -0,0 +1,386 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from layer0 import layer0_entry
|
||||
from layer0.shadow_classifier import ShadowEvalResult
|
||||
|
||||
from .tool import OracleAnswerTool
|
||||
|
||||
MAX_BYTES_DEFAULT = 32_000
|
||||
|
||||
|
||||
def _max_bytes() -> int:
|
||||
raw = (os.getenv("VM_MCP_MAX_BYTES") or "").strip()
|
||||
if not raw:
|
||||
return MAX_BYTES_DEFAULT
|
||||
try:
|
||||
return max(4_096, int(raw))
|
||||
except ValueError:
|
||||
return MAX_BYTES_DEFAULT
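Illustrative effect of the VM_MCP_MAX_BYTES override handled above (values are examples; the 4,096-byte floor and 32,000-byte default come straight from this helper):

import os

os.environ["VM_MCP_MAX_BYTES"] = "64000"  # _max_bytes() -> 64000
os.environ["VM_MCP_MAX_BYTES"] = "100"    # _max_bytes() -> 4096 (clamped by max(4_096, ...))
os.environ["VM_MCP_MAX_BYTES"] = "oops"   # _max_bytes() -> 32000 (falls back to MAX_BYTES_DEFAULT)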
|
||||
|
||||
|
||||
def _redact(obj: Any) -> Any:
|
||||
sensitive_keys = ("token", "secret", "password", "private", "key", "certificate")
|
||||
|
||||
if isinstance(obj, dict):
|
||||
out: Dict[str, Any] = {}
|
||||
for k, v in obj.items():
|
||||
if any(s in str(k).lower() for s in sensitive_keys):
|
||||
out[k] = "<REDACTED>"
|
||||
else:
|
||||
out[k] = _redact(v)
|
||||
return out
|
||||
if isinstance(obj, list):
|
||||
return [_redact(v) for v in obj]
|
||||
if isinstance(obj, str):
|
||||
if obj.startswith("ghp_") or obj.startswith("github_pat_"):
|
||||
return "<REDACTED>"
|
||||
return obj
|
||||
return obj
|
||||
|
||||
|
||||
def _safe_json(payload: Dict[str, Any]) -> str:
|
||||
payload = _redact(payload)
|
||||
raw = json.dumps(payload, ensure_ascii=False, separators=(",", ":"), default=str)
|
||||
if len(raw.encode("utf-8")) <= _max_bytes():
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
truncated = {
|
||||
"ok": payload.get("ok", True),
|
||||
"truncated": True,
|
||||
"summary": payload.get("summary", "Response exceeded max size; truncated."),
|
||||
"next_steps": payload.get(
|
||||
"next_steps",
|
||||
["request narrower outputs (e.g., fewer frameworks or shorter question)"],
|
||||
),
|
||||
}
|
||||
return json.dumps(truncated, ensure_ascii=False, indent=2, default=str)
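A small illustration of the redaction behaviour above (made-up values; keys matching the sensitive-key list and GitHub-token-shaped strings are masked, everything else passes through):

example = {"api_token": "ghp_abc123", "nested": {"password": "hunter2", "note": "ok"}}
assert _redact(example) == {
    "api_token": "<REDACTED>",
    "nested": {"password": "<REDACTED>", "note": "ok"},
}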
|
||||
|
||||
|
||||
def _mcp_text_result(
|
||||
payload: Dict[str, Any], *, is_error: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
result: Dict[str, Any] = {
|
||||
"content": [{"type": "text", "text": _safe_json(payload)}]
|
||||
}
|
||||
if is_error:
|
||||
result["isError"] = True
|
||||
return result
|
||||
|
||||
|
||||
TOOLS: List[Dict[str, Any]] = [
|
||||
{
|
||||
"name": "oracle_answer",
|
||||
"description": "Answer a compliance/security question (optionally via NVIDIA LLM) and map to frameworks.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question to answer.",
|
||||
},
|
||||
"frameworks": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Frameworks to reference (e.g., ['NIST-CSF','ISO-27001','GDPR']).",
|
||||
},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["strict", "advisory"],
|
||||
"default": "strict",
|
||||
"description": "strict=conservative, advisory=exploratory.",
|
||||
},
|
||||
"local_only": {
|
||||
"type": "boolean",
|
||||
"description": "If true, skip NVIDIA API calls (uses local-only mode). Defaults to true when NVIDIA_API_KEY is missing.",
|
||||
},
|
||||
},
|
||||
"required": ["question"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class OracleAnswerTools:
|
||||
async def oracle_answer(
|
||||
self,
|
||||
*,
|
||||
question: str,
|
||||
frameworks: Optional[List[str]] = None,
|
||||
mode: str = "strict",
|
||||
local_only: Optional[bool] = None,
|
||||
) -> Dict[str, Any]:
|
||||
routing_action, shadow = layer0_entry(question)
|
||||
if routing_action != "HANDOFF_TO_LAYER1":
|
||||
return _layer0_payload(routing_action, shadow)
|
||||
|
||||
local_only_use = (
|
||||
bool(local_only)
|
||||
if local_only is not None
|
||||
else not bool((os.getenv("NVIDIA_API_KEY") or "").strip())
|
||||
)
|
||||
|
||||
try:
|
||||
tool = OracleAnswerTool(
|
||||
default_frameworks=frameworks,
|
||||
use_local_only=local_only_use,
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
return {
|
||||
"ok": False,
|
||||
"summary": str(e),
|
||||
"data": {
|
||||
"local_only": local_only_use,
|
||||
"has_nvidia_api_key": bool(
|
||||
(os.getenv("NVIDIA_API_KEY") or "").strip()
|
||||
),
|
||||
},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"Set NVIDIA_API_KEY to enable live answers",
|
||||
"Or call oracle_answer(local_only=true, ...)",
|
||||
],
|
||||
}
|
||||
|
||||
resp = await tool.answer(question=question, frameworks=frameworks, mode=mode)
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": "Oracle answer generated.",
|
||||
"data": {
|
||||
"question": question,
|
||||
"mode": mode,
|
||||
"frameworks": frameworks or tool.default_frameworks,
|
||||
"local_only": local_only_use,
|
||||
"model": resp.model,
|
||||
"answer": resp.answer,
|
||||
"framework_hits": resp.framework_hits,
|
||||
"reasoning": resp.reasoning,
|
||||
},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"If the answer is incomplete, add more specifics to the question or include more frameworks.",
|
||||
],
|
||||
}
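A minimal local-only invocation sketch for the tool above (question and frameworks are illustrative; local_only=True skips the NVIDIA API entirely, and the Layer 0 gate still runs first):

import asyncio

result = asyncio.run(
    OracleAnswerTools().oracle_answer(
        question="Do we need MFA for admin access under ISO-27001?",
        frameworks=["ISO-27001"],
        local_only=True,
    )
)
print(result["summary"])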
|
||||
|
||||
|
||||
class StdioJsonRpc:
|
||||
def __init__(self) -> None:
|
||||
self._in = sys.stdin.buffer
|
||||
self._out = sys.stdout.buffer
|
||||
self._mode: str | None = None # "headers" | "line"
|
||||
|
||||
def read_message(self) -> Optional[Dict[str, Any]]:
|
||||
while True:
|
||||
if self._mode == "line":
|
||||
line = self._in.readline()
|
||||
if not line:
|
||||
return None
|
||||
raw = line.decode("utf-8", "replace").strip()
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(msg, dict):
|
||||
return msg
|
||||
continue
|
||||
|
||||
first = self._in.readline()
|
||||
if not first:
|
||||
return None
|
||||
|
||||
if first in (b"\r\n", b"\n"):
|
||||
continue
|
||||
|
||||
# Auto-detect newline-delimited JSON framing.
|
||||
if self._mode is None and first.lstrip().startswith(b"{"):
|
||||
try:
|
||||
msg = json.loads(first.decode("utf-8", "replace"))
|
||||
except Exception:
|
||||
msg = None
|
||||
if isinstance(msg, dict):
|
||||
self._mode = "line"
|
||||
return msg
|
||||
|
||||
headers: Dict[str, str] = {}
|
||||
try:
|
||||
text = first.decode("utf-8", "replace").strip()
|
||||
except Exception:
|
||||
continue
|
||||
if ":" not in text:
|
||||
continue
|
||||
k, v = text.split(":", 1)
|
||||
headers[k.lower().strip()] = v.strip()
|
||||
|
||||
while True:
|
||||
line = self._in.readline()
|
||||
if not line:
|
||||
return None
|
||||
if line in (b"\r\n", b"\n"):
|
||||
break
|
||||
try:
|
||||
text = line.decode("utf-8", "replace").strip()
|
||||
except Exception:
|
||||
continue
|
||||
if ":" not in text:
|
||||
continue
|
||||
k, v = text.split(":", 1)
|
||||
headers[k.lower().strip()] = v.strip()
|
||||
|
||||
if "content-length" not in headers:
|
||||
return None
|
||||
try:
|
||||
length = int(headers["content-length"])
|
||||
except ValueError:
|
||||
return None
|
||||
body = self._in.read(length)
|
||||
if not body:
|
||||
return None
|
||||
self._mode = "headers"
|
||||
msg = json.loads(body.decode("utf-8", "replace"))
|
||||
if isinstance(msg, dict):
|
||||
return msg
|
||||
return None
|
||||
|
||||
def write_message(self, message: Dict[str, Any]) -> None:
|
||||
if self._mode == "line":
|
||||
payload = json.dumps(
|
||||
message, ensure_ascii=False, separators=(",", ":"), default=str
|
||||
).encode("utf-8")
|
||||
self._out.write(payload + b"\n")
|
||||
self._out.flush()
|
||||
return
|
||||
|
||||
body = json.dumps(
|
||||
message, ensure_ascii=False, separators=(",", ":"), default=str
|
||||
).encode("utf-8")
|
||||
header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
|
||||
self._out.write(header)
|
||||
self._out.write(body)
|
||||
self._out.flush()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
tools = OracleAnswerTools()
|
||||
rpc = StdioJsonRpc()
|
||||
|
||||
handlers: Dict[str, Callable[[Dict[str, Any]], Any]] = {
|
||||
"oracle_answer": lambda a: tools.oracle_answer(**a),
|
||||
}
|
||||
|
||||
while True:
|
||||
msg = rpc.read_message()
|
||||
if msg is None:
|
||||
return
|
||||
|
||||
method = msg.get("method")
|
||||
msg_id = msg.get("id")
|
||||
params = msg.get("params") or {}
|
||||
|
||||
try:
|
||||
if method == "initialize":
|
||||
result = {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"serverInfo": {"name": "oracle_answer", "version": "0.1.0"},
|
||||
"capabilities": {"tools": {}},
|
||||
}
|
||||
rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result})
|
||||
continue
|
||||
|
||||
if method == "tools/list":
|
||||
rpc.write_message(
|
||||
{"jsonrpc": "2.0", "id": msg_id, "result": {"tools": TOOLS}}
|
||||
)
|
||||
continue
|
||||
|
||||
if method == "tools/call":
|
||||
tool_name = str(params.get("name") or "")
|
||||
args = params.get("arguments") or {}
|
||||
handler = handlers.get(tool_name)
|
||||
if not handler:
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
{
|
||||
"ok": False,
|
||||
"summary": f"Unknown tool: {tool_name}",
|
||||
"data": {"known_tools": sorted(handlers.keys())},
|
||||
"truncated": False,
|
||||
"next_steps": ["Call tools/list"],
|
||||
},
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
payload = asyncio.run(handler(args)) # type: ignore[arg-type]
|
||||
is_error = (
|
||||
not bool(payload.get("ok", True))
|
||||
if isinstance(payload, dict)
|
||||
else False
|
||||
)
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(payload, is_error=is_error),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Ignore notifications.
|
||||
if msg_id is None:
|
||||
continue
|
||||
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
{"ok": False, "summary": f"Unsupported method: {method}"},
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
if msg_id is not None:
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
{"ok": False, "summary": f"fatal error: {e}"},
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _layer0_payload(routing_action: str, shadow: ShadowEvalResult) -> Dict[str, Any]:
|
||||
if routing_action == "FAIL_CLOSED":
|
||||
return {"ok": False, "summary": "Layer 0: cannot comply with this request."}
|
||||
if routing_action == "HANDOFF_TO_GUARDRAILS":
|
||||
reason = shadow.reason or "governance_violation"
|
||||
return {
|
||||
"ok": False,
|
||||
"summary": f"Layer 0: governance violation detected ({reason}).",
|
||||
}
|
||||
if routing_action == "PROMPT_FOR_CLARIFICATION":
|
||||
return {
|
||||
"ok": False,
|
||||
"summary": "Layer 0: request is ambiguous. Please clarify and retry.",
|
||||
}
|
||||
return {"ok": False, "summary": "Layer 0: unrecognized routing action; refusing."}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -9,7 +9,11 @@ Separate from CLI/API wrapper for clean testability.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -92,12 +96,10 @@ class OracleAnswerTool:
|
||||
if self.use_local_only:
|
||||
return "Local-only mode: skipping NVIDIA API call"
|
||||
|
||||
if not httpx:
|
||||
raise ImportError("httpx not installed. Install with: pip install httpx")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
payload = {
|
||||
@@ -108,18 +110,45 @@ class OracleAnswerTool:
|
||||
"max_tokens": 1024,
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.NVIDIA_API_BASE}/chat/completions",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
# Prefer httpx when available; otherwise fall back to stdlib urllib to avoid extra deps.
|
||||
if httpx:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.NVIDIA_API_BASE}/chat/completions",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["choices"][0]["message"]["content"]
|
||||
except Exception as e: # noqa: BLE001
|
||||
return f"(API Error: {str(e)}) Falling back to local analysis..."
|
||||
|
||||
def _urllib_post() -> str:
|
||||
req = urllib.request.Request(
|
||||
url=f"{self.NVIDIA_API_BASE}/chat/completions",
|
||||
method="POST",
|
||||
headers=headers,
|
||||
data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
raw = resp.read().decode("utf-8", "replace")
|
||||
data = json.loads(raw)
|
||||
return data["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
except urllib.error.HTTPError as e:
|
||||
detail = ""
|
||||
try:
|
||||
detail = e.read().decode("utf-8", "replace")
|
||||
except Exception:
|
||||
detail = str(e)
|
||||
raise RuntimeError(f"HTTP {e.code}: {detail}") from e
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_urllib_post)
|
||||
except Exception as e: # noqa: BLE001
|
||||
return f"(API Error: {str(e)}) Falling back to local analysis..."
|
||||
|
||||
async def answer(
|
||||
|
||||
@@ -10,22 +10,24 @@ This module provides tools to:
|
||||
Export primary classes and functions:
|
||||
"""
|
||||
|
||||
from mcp.waf_intelligence.analyzer import (
|
||||
WAFRuleAnalyzer,
|
||||
RuleViolation,
|
||||
__version__ = "0.3.0"
|
||||
|
||||
from .analyzer import (
|
||||
AnalysisResult,
|
||||
RuleViolation,
|
||||
WAFRuleAnalyzer,
|
||||
)
|
||||
from mcp.waf_intelligence.generator import (
|
||||
WAFRuleGenerator,
|
||||
GeneratedRule,
|
||||
)
|
||||
from mcp.waf_intelligence.compliance import (
|
||||
from .compliance import (
|
||||
ComplianceMapper,
|
||||
FrameworkMapping,
|
||||
)
|
||||
from mcp.waf_intelligence.orchestrator import (
|
||||
WAFIntelligence,
|
||||
from .generator import (
|
||||
GeneratedRule,
|
||||
WAFRuleGenerator,
|
||||
)
|
||||
from .orchestrator import (
|
||||
WAFInsight,
|
||||
WAFIntelligence,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any, Dict, List
|
||||
from layer0 import layer0_entry
|
||||
from layer0.shadow_classifier import ShadowEvalResult
|
||||
|
||||
from . import __version__ as WAF_INTEL_VERSION
|
||||
from .orchestrator import WAFInsight, WAFIntelligence
|
||||
|
||||
|
||||
@@ -56,11 +57,18 @@ def run_cli(argv: List[str] | None = None) -> int:
|
||||
action="store_true",
|
||||
help="Exit with non-zero code if any error-severity violations are found.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
action="version",
|
||||
version=f"%(prog)s {WAF_INTEL_VERSION}",
|
||||
)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
# Layer 0: pre-boot Shadow Eval gate.
|
||||
routing_action, shadow = layer0_entry(f"waf_intel_cli file={args.file} limit={args.limit}")
|
||||
routing_action, shadow = layer0_entry(
|
||||
f"waf_intel_cli file={args.file} limit={args.limit}"
|
||||
)
|
||||
if routing_action != "HANDOFF_TO_LAYER1":
|
||||
_render_layer0_block(routing_action, shadow)
|
||||
return 1
|
||||
@@ -90,7 +98,9 @@ def run_cli(argv: List[str] | None = None) -> int:
|
||||
print(f"\nWAF Intelligence Report for: {path}\n{'-' * 72}")
|
||||
|
||||
if not insights:
|
||||
print("No high-severity, high-confidence issues detected based on current heuristics.")
|
||||
print(
|
||||
"No high-severity, high-confidence issues detected based on current heuristics."
|
||||
)
|
||||
return 0
|
||||
|
||||
for idx, insight in enumerate(insights, start=1):
|
||||
@@ -119,7 +129,9 @@ def run_cli(argv: List[str] | None = None) -> int:
|
||||
if insight.mappings:
|
||||
print("\nCompliance Mapping:")
|
||||
for mapping in insight.mappings:
|
||||
print(f" - {mapping.framework} {mapping.control_id}: {mapping.description}")
|
||||
print(
|
||||
f" - {mapping.framework} {mapping.control_id}: {mapping.description}"
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
MANAGED_WAF_RULESET_IDS = (
|
||||
# Cloudflare managed WAF ruleset IDs (last updated 2025-12-18).
|
||||
"efb7b8c949ac4650a09736fc376e9aee", # Cloudflare Managed Ruleset
|
||||
"4814384a9e5d4991b9815dcfc25d2f1f", # OWASP Core Ruleset
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuleViolation:
|
||||
@@ -57,6 +64,20 @@ class WAFRuleAnalyzer:
|
||||
Analyze Cloudflare WAF rules from Terraform with a quality-first posture.
|
||||
"""
|
||||
|
||||
def _has_managed_waf_rules(self, text: str) -> bool:
|
||||
text_lower = text.lower()
|
||||
|
||||
if "managed_rules" in text_lower:
|
||||
return True
|
||||
|
||||
if re.search(r'phase\s*=\s*"http_request_firewall_managed"', text_lower):
|
||||
return True
|
||||
|
||||
if "cf.waf" in text_lower:
|
||||
return True
|
||||
|
||||
return any(ruleset_id in text_lower for ruleset_id in MANAGED_WAF_RULESET_IDS)
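Illustrative checks against the heuristic above (Terraform fragments are made up; any one signal — a managed_rules block, the managed firewall phase, cf.waf fields, or a known managed ruleset ID — counts as coverage):

analyzer = WAFRuleAnalyzer()
covered = 'resource "cloudflare_ruleset" "waf" { phase = "http_request_firewall_managed" }'
uncovered = 'resource "cloudflare_record" "www" { type = "A" }'
assert analyzer._has_managed_waf_rules(covered) is True
assert analyzer._has_managed_waf_rules(uncovered) is False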
|
||||
|
||||
def analyze_file(
|
||||
self,
|
||||
path: str | Path,
|
||||
@@ -70,7 +91,7 @@ class WAFRuleAnalyzer:
|
||||
violations: List[RuleViolation] = []
|
||||
|
||||
# Example heuristic: no managed rules present
|
||||
if "managed_rules" not in text:
|
||||
if not self._has_managed_waf_rules(text):
|
||||
violations.append(
|
||||
RuleViolation(
|
||||
rule_id=None,
|
||||
@@ -102,7 +123,7 @@ class WAFRuleAnalyzer:
|
||||
violations=violations,
|
||||
metadata={
|
||||
"file_size": path.stat().st_size,
|
||||
"heuristics_version": "0.2.0",
|
||||
"heuristics_version": "0.3.0",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -125,7 +146,7 @@ class WAFRuleAnalyzer:
|
||||
tmp_path = Path(source_name)
|
||||
violations: List[RuleViolation] = []
|
||||
|
||||
if "managed_rules" not in text:
|
||||
if not self._has_managed_waf_rules(text):
|
||||
violations.append(
|
||||
RuleViolation(
|
||||
rule_id=None,
|
||||
@@ -141,7 +162,7 @@ class WAFRuleAnalyzer:
|
||||
result = AnalysisResult(
|
||||
source=str(tmp_path),
|
||||
violations=violations,
|
||||
metadata={"heuristics_version": "0.2.0"},
|
||||
metadata={"heuristics_version": "0.3.0"},
|
||||
)
|
||||
|
||||
result.violations = result.top_violations(
|
||||
@@ -161,27 +182,37 @@ class WAFRuleAnalyzer:
|
||||
) -> AnalysisResult:
|
||||
"""
|
||||
Enhanced analysis using threat intelligence data.
|
||||
|
||||
|
||||
Args:
|
||||
path: WAF config file path
|
||||
threat_indicators: List of ThreatIndicator objects from threat_intel module
|
||||
min_severity: Minimum severity to include
|
||||
min_confidence: Minimum confidence threshold
|
||||
|
||||
|
||||
Returns:
|
||||
AnalysisResult with violations informed by threat intel
|
||||
"""
|
||||
# Start with base analysis
|
||||
base_result = self.analyze_file(path, min_severity=min_severity, min_confidence=min_confidence)
|
||||
|
||||
base_result = self.analyze_file(
|
||||
path, min_severity=min_severity, min_confidence=min_confidence
|
||||
)
|
||||
|
||||
path = Path(path)
|
||||
text = path.read_text(encoding="utf-8")
|
||||
text_lower = text.lower()
|
||||
|
||||
|
||||
# Check if threat indicators are addressed by existing rules
|
||||
critical_ips = [i for i in threat_indicators if i.indicator_type == "ip" and i.severity in ("critical", "high")]
|
||||
critical_patterns = [i for i in threat_indicators if i.indicator_type == "pattern" and i.severity in ("critical", "high")]
|
||||
|
||||
critical_ips = [
|
||||
i
|
||||
for i in threat_indicators
|
||||
if i.indicator_type == "ip" and i.severity in ("critical", "high")
|
||||
]
|
||||
critical_patterns = [
|
||||
i
|
||||
for i in threat_indicators
|
||||
if i.indicator_type == "pattern" and i.severity in ("critical", "high")
|
||||
]
|
||||
|
||||
# Check for IP blocking coverage
|
||||
if critical_ips:
|
||||
ip_block_present = "ip.src" in text_lower or "cf.client.ip" in text_lower
|
||||
@@ -197,14 +228,14 @@ class WAFRuleAnalyzer:
|
||||
hint=f"Add IP blocking rules for identified threat actors. Sample IPs: {', '.join(i.value for i in critical_ips[:3])}",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Check for pattern-based attack coverage
|
||||
attack_types_seen = set()
|
||||
for ind in critical_patterns:
|
||||
for tag in ind.tags:
|
||||
if tag in ("sqli", "xss", "rce", "path_traversal"):
|
||||
attack_types_seen.add(tag)
|
||||
|
||||
|
||||
# Check managed ruleset coverage
|
||||
for attack_type in attack_types_seen:
|
||||
if attack_type not in text_lower and f'"{attack_type}"' not in text_lower:
|
||||
@@ -219,13 +250,12 @@ class WAFRuleAnalyzer:
|
||||
hint=f"Enable Cloudflare managed rules for {attack_type.upper()} protection.",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Update metadata with threat intel stats
|
||||
base_result.metadata["threat_intel"] = {
|
||||
"critical_ips": len(critical_ips),
|
||||
"critical_patterns": len(critical_patterns),
|
||||
"attack_types_seen": list(attack_types_seen),
|
||||
}
|
||||
|
||||
return base_result
|
||||
|
||||
return base_result
|
||||
|
||||
632
mcp/waf_intelligence/mcp_server.py
Normal file
@@ -0,0 +1,632 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from cloudflare.layer0 import layer0_entry
|
||||
from cloudflare.layer0.shadow_classifier import ShadowEvalResult
|
||||
|
||||
from .orchestrator import ThreatAssessment, WAFInsight, WAFIntelligence
|
||||
|
||||
MAX_BYTES_DEFAULT = 32_000
|
||||
|
||||
|
||||
def _cloudflare_root() -> Path:
|
||||
# mcp_server.py -> waf_intelligence -> mcp -> cloudflare
|
||||
return Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
def _max_bytes() -> int:
|
||||
raw = (os.getenv("VM_MCP_MAX_BYTES") or "").strip()
|
||||
if not raw:
|
||||
return MAX_BYTES_DEFAULT
|
||||
try:
|
||||
return max(4_096, int(raw))
|
||||
except ValueError:
|
||||
return MAX_BYTES_DEFAULT
|
||||
|
||||
|
||||
def _redact(obj: Any) -> Any:
|
||||
sensitive_keys = ("token", "secret", "password", "private", "key", "certificate")
|
||||
|
||||
if isinstance(obj, dict):
|
||||
out: Dict[str, Any] = {}
|
||||
for k, v in obj.items():
|
||||
if any(s in str(k).lower() for s in sensitive_keys):
|
||||
out[k] = "<REDACTED>"
|
||||
else:
|
||||
out[k] = _redact(v)
|
||||
return out
|
||||
if isinstance(obj, list):
|
||||
return [_redact(v) for v in obj]
|
||||
if isinstance(obj, str):
|
||||
if obj.startswith("ghp_") or obj.startswith("github_pat_"):
|
||||
return "<REDACTED>"
|
||||
return obj
|
||||
return obj
|
||||
|
||||
|
||||
def _safe_json(payload: Dict[str, Any]) -> str:
|
||||
payload = _redact(payload)
|
||||
raw = json.dumps(payload, ensure_ascii=False, separators=(",", ":"), default=str)
|
||||
if len(raw.encode("utf-8")) <= _max_bytes():
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
truncated = {
|
||||
"ok": payload.get("ok", True),
|
||||
"truncated": True,
|
||||
"summary": payload.get("summary", "Response exceeded max size; truncated."),
|
||||
"next_steps": payload.get(
|
||||
"next_steps",
|
||||
[
|
||||
"request fewer files/insights (limit=...)",
|
||||
"use higher min_severity to reduce output",
|
||||
],
|
||||
),
|
||||
}
|
||||
return json.dumps(truncated, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
|
||||
def _mcp_text_result(
|
||||
payload: Dict[str, Any], *, is_error: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
result: Dict[str, Any] = {
|
||||
"content": [{"type": "text", "text": _safe_json(payload)}]
|
||||
}
|
||||
if is_error:
|
||||
result["isError"] = True
|
||||
return result
|
||||
|
||||
|
||||
def _insight_to_dict(insight: WAFInsight) -> Dict[str, Any]:
|
||||
return asdict(insight)
|
||||
|
||||
|
||||
def _assessment_to_dict(assessment: ThreatAssessment) -> Dict[str, Any]:
|
||||
violations = []
|
||||
if assessment.analysis_result and getattr(
|
||||
assessment.analysis_result, "violations", None
|
||||
):
|
||||
violations = list(assessment.analysis_result.violations)
|
||||
|
||||
severity_counts = {"error": 0, "warning": 0, "info": 0}
|
||||
for v in violations:
|
||||
sev = getattr(v, "severity", "info")
|
||||
if sev in severity_counts:
|
||||
severity_counts[sev] += 1
|
||||
|
||||
return {
|
||||
"risk_score": assessment.risk_score,
|
||||
"risk_level": assessment.risk_level,
|
||||
"classification_summary": assessment.classification_summary,
|
||||
"recommended_actions": assessment.recommended_actions,
|
||||
"analysis": {
|
||||
"has_config_analysis": assessment.analysis_result is not None,
|
||||
"violations_total": len(violations),
|
||||
"violations_by_severity": severity_counts,
|
||||
},
|
||||
"has_threat_intel": assessment.threat_report is not None,
|
||||
"generated_at": str(assessment.generated_at),
|
||||
}
|
||||
|
||||
|
||||
TOOLS: List[Dict[str, Any]] = [
|
||||
{
|
||||
"name": "waf_capabilities",
|
||||
"description": "List available WAF Intelligence capabilities.",
|
||||
"inputSchema": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "analyze_waf",
|
||||
"description": "Analyze Terraform WAF file(s) and return curated insights (legacy alias for waf_analyze).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file": {
|
||||
"type": "string",
|
||||
"description": "Single file path to analyze.",
|
||||
},
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of file paths or glob patterns to analyze.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"default": 3,
|
||||
"description": "Max insights per file.",
|
||||
},
|
||||
"severity_threshold": {
|
||||
"type": "string",
|
||||
"enum": ["info", "warning", "error"],
|
||||
"default": "warning",
|
||||
"description": "Minimum severity to include (alias for min_severity).",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "waf_analyze",
|
||||
"description": "Analyze Terraform WAF file(s) and return curated insights (requires file or files).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file": {
|
||||
"type": "string",
|
||||
"description": "Single file path to analyze.",
|
||||
},
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of file paths or glob patterns to analyze.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"default": 3,
|
||||
"description": "Max insights per file.",
|
||||
},
|
||||
"min_severity": {
|
||||
"type": "string",
|
||||
"enum": ["info", "warning", "error"],
|
||||
"default": "warning",
|
||||
"description": "Minimum severity to include.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "waf_assess",
|
||||
"description": "Run a broader assessment (optionally includes threat intel collection).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"waf_config_path": {
|
||||
"type": "string",
|
||||
"description": "Path to Terraform WAF config (default: terraform/waf.tf).",
|
||||
},
|
||||
"include_threat_intel": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
"description": "If true, attempt to collect threat intel (may require network and credentials).",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "waf_generate_gitops_proposals",
|
||||
"description": "Generate GitOps-ready rule proposals (best-effort; requires threat intel to produce output).",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"waf_config_path": {
|
||||
"type": "string",
|
||||
"description": "Path to Terraform WAF config (default: terraform/waf.tf).",
|
||||
},
|
||||
"include_threat_intel": {
|
||||
"type": "boolean",
|
||||
"default": True,
|
||||
"description": "Attempt to collect threat intel before proposing rules.",
|
||||
},
|
||||
"max_proposals": {
|
||||
"type": "integer",
|
||||
"default": 5,
|
||||
"description": "Maximum proposals to generate.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class WafIntelligenceTools:
|
||||
def __init__(self) -> None:
|
||||
self.workspace_root = _cloudflare_root()
|
||||
self.repo_root = self.workspace_root.parent
|
||||
self.waf = WAFIntelligence(workspace_path=str(self.workspace_root))
|
||||
|
||||
def _resolve_path(self, raw: str) -> Path:
|
||||
path = Path(raw)
|
||||
if path.is_absolute():
|
||||
return path
|
||||
|
||||
candidates = [
|
||||
Path.cwd() / path,
|
||||
self.workspace_root / path,
|
||||
self.repo_root / path,
|
||||
]
|
||||
for candidate in candidates:
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return self.workspace_root / path
|
||||
|
||||
def waf_capabilities(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": "WAF Intelligence capabilities.",
|
||||
"data": {"capabilities": self.waf.capabilities},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"Call waf_analyze(file=..., limit=...) to analyze config.",
|
||||
"Call waf_assess(include_threat_intel=true) for a broader assessment.",
|
||||
],
|
||||
}
|
||||
|
||||
def waf_analyze(
|
||||
self,
|
||||
*,
|
||||
file: Optional[str] = None,
|
||||
files: Optional[List[str]] = None,
|
||||
limit: int = 3,
|
||||
min_severity: str = "warning",
|
||||
) -> Dict[str, Any]:
|
||||
paths: List[str] = []
|
||||
if files:
|
||||
for pattern in files:
|
||||
paths.extend(glob.glob(pattern))
|
||||
if file:
|
||||
paths.append(file)
|
||||
|
||||
seen = set()
|
||||
unique_paths: List[str] = []
|
||||
for p in paths:
|
||||
if p not in seen:
|
||||
seen.add(p)
|
||||
unique_paths.append(p)
|
||||
|
||||
if not unique_paths:
|
||||
return {
|
||||
"ok": False,
|
||||
"summary": "Provide 'file' or 'files' to analyze.",
|
||||
"truncated": False,
|
||||
"next_steps": ["Call waf_analyze(file='terraform/waf.tf')"],
|
||||
}
|
||||
|
||||
results: List[Dict[str, Any]] = []
|
||||
for p in unique_paths:
|
||||
path = self._resolve_path(p)
|
||||
if not path.exists():
|
||||
results.append(
|
||||
{
|
||||
"file": str(path),
|
||||
"ok": False,
|
||||
"summary": "File not found.",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
insights = self.waf.analyze_and_recommend(
|
||||
str(path),
|
||||
limit=limit,
|
||||
min_severity=min_severity,
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"file": str(path),
|
||||
"ok": True,
|
||||
"insights": [_insight_to_dict(i) for i in insights],
|
||||
}
|
||||
)
|
||||
|
||||
ok = all(r.get("ok") for r in results)
|
||||
return {
|
||||
"ok": ok,
|
||||
"summary": f"Analyzed {len(results)} file(s).",
|
||||
"data": {"results": results},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"Raise/lower min_severity or limit to tune output size.",
|
||||
],
|
||||
}
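Sketch of calling waf_analyze directly with a glob (paths and thresholds are illustrative; globs are expanded, de-duplicated, and resolved against cwd, the workspace root, then the repo root):

tools = WafIntelligenceTools()
report = tools.waf_analyze(files=["terraform/*.tf"], limit=2, min_severity="error")
print(report["summary"])
for item in report.get("data", {}).get("results", []):
    print(item["file"], item.get("ok"))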
|
||||
|
||||
def waf_assess(
|
||||
self,
|
||||
*,
|
||||
waf_config_path: Optional[str] = None,
|
||||
include_threat_intel: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
waf_config_path_resolved = (
|
||||
str(self._resolve_path(waf_config_path)) if waf_config_path else None
|
||||
)
|
||||
assessment = self.waf.full_assessment(
|
||||
waf_config_path=waf_config_path_resolved,
|
||||
include_threat_intel=include_threat_intel,
|
||||
)
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": "WAF assessment complete.",
|
||||
"data": _assessment_to_dict(assessment),
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"Call waf_generate_gitops_proposals(...) to draft Terraform rule proposals (best-effort).",
|
||||
],
|
||||
}
|
||||
|
||||
def waf_generate_gitops_proposals(
|
||||
self,
|
||||
*,
|
||||
waf_config_path: Optional[str] = None,
|
||||
include_threat_intel: bool = True,
|
||||
max_proposals: int = 5,
|
||||
) -> Dict[str, Any]:
|
||||
waf_config_path_resolved = (
|
||||
str(self._resolve_path(waf_config_path)) if waf_config_path else None
|
||||
)
|
||||
assessment = self.waf.full_assessment(
|
||||
waf_config_path=waf_config_path_resolved,
|
||||
include_threat_intel=include_threat_intel,
|
||||
)
|
||||
proposals = self.waf.generate_gitops_proposals(
|
||||
threat_report=assessment.threat_report,
|
||||
max_proposals=max_proposals,
|
||||
)
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"Generated {len(proposals)} proposal(s).",
|
||||
"data": {
|
||||
"assessment": _assessment_to_dict(assessment),
|
||||
"proposals": proposals,
|
||||
},
|
||||
"truncated": False,
|
||||
"next_steps": [
|
||||
"If proposals are empty, enable threat intel and ensure required credentials/log sources exist.",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class StdioJsonRpc:
|
||||
def __init__(self) -> None:
|
||||
self._in = sys.stdin.buffer
|
||||
self._out = sys.stdout.buffer
|
||||
self._mode: str | None = None # "headers" | "line"
|
||||
|
||||
def read_message(self) -> Optional[Dict[str, Any]]:
|
||||
while True:
|
||||
if self._mode == "line":
|
||||
line = self._in.readline()
|
||||
if not line:
|
||||
return None
|
||||
raw = line.decode("utf-8", "replace").strip()
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(msg, dict):
|
||||
return msg
|
||||
continue
|
||||
|
||||
first = self._in.readline()
|
||||
if not first:
|
||||
return None
|
||||
|
||||
if first in (b"\r\n", b"\n"):
|
||||
continue
|
||||
|
||||
# Auto-detect newline-delimited JSON framing.
|
||||
if self._mode is None and first.lstrip().startswith(b"{"):
|
||||
try:
|
||||
msg = json.loads(first.decode("utf-8", "replace"))
|
||||
except Exception:
|
||||
msg = None
|
||||
if isinstance(msg, dict):
|
||||
self._mode = "line"
|
||||
return msg
|
||||
|
||||
headers: Dict[str, str] = {}
|
||||
try:
|
||||
text = first.decode("utf-8", "replace").strip()
|
||||
except Exception:
|
||||
continue
|
||||
if ":" not in text:
|
||||
continue
|
||||
k, v = text.split(":", 1)
|
||||
headers[k.lower().strip()] = v.strip()
|
||||
|
||||
while True:
|
||||
line = self._in.readline()
|
||||
if not line:
|
||||
return None
|
||||
if line in (b"\r\n", b"\n"):
|
||||
break
|
||||
try:
|
||||
text = line.decode("utf-8", "replace").strip()
|
||||
except Exception:
|
||||
continue
|
||||
if ":" not in text:
|
||||
continue
|
||||
k, v = text.split(":", 1)
|
||||
headers[k.lower().strip()] = v.strip()
|
||||
|
||||
if "content-length" not in headers:
|
||||
return None
|
||||
try:
|
||||
length = int(headers["content-length"])
|
||||
except ValueError:
|
||||
return None
|
||||
body = self._in.read(length)
|
||||
if not body:
|
||||
return None
|
||||
self._mode = "headers"
|
||||
msg = json.loads(body.decode("utf-8", "replace"))
|
||||
if isinstance(msg, dict):
|
||||
return msg
|
||||
return None
|
||||
|
||||
def write_message(self, message: Dict[str, Any]) -> None:
|
||||
if self._mode == "line":
|
||||
payload = json.dumps(
|
||||
message, ensure_ascii=False, separators=(",", ":"), default=str
|
||||
).encode("utf-8")
|
||||
self._out.write(payload + b"\n")
|
||||
self._out.flush()
|
||||
return
|
||||
|
||||
body = json.dumps(
|
||||
message, ensure_ascii=False, separators=(",", ":"), default=str
|
||||
).encode("utf-8")
|
||||
header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
|
||||
self._out.write(header)
|
||||
self._out.write(body)
|
||||
self._out.flush()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
tools = WafIntelligenceTools()
|
||||
rpc = StdioJsonRpc()
|
||||
|
||||
handlers: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] = {
|
||||
"waf_capabilities": lambda a: tools.waf_capabilities(),
|
||||
"analyze_waf": lambda a: tools.waf_analyze(
|
||||
file=a.get("file"),
|
||||
files=a.get("files"),
|
||||
limit=int(a.get("limit", 3)),
|
||||
min_severity=str(a.get("severity_threshold", "warning")),
|
||||
),
|
||||
"waf_analyze": lambda a: tools.waf_analyze(**a),
|
||||
"waf_assess": lambda a: tools.waf_assess(**a),
|
||||
"waf_generate_gitops_proposals": lambda a: tools.waf_generate_gitops_proposals(
|
||||
**a
|
||||
),
|
||||
}
|
||||
|
||||
while True:
|
||||
msg = rpc.read_message()
|
||||
if msg is None:
|
||||
return
|
||||
|
||||
method = msg.get("method")
|
||||
msg_id = msg.get("id")
|
||||
params = msg.get("params") or {}
|
||||
|
||||
try:
|
||||
if method == "initialize":
|
||||
result = {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"serverInfo": {"name": "waf_intelligence", "version": "0.1.0"},
|
||||
"capabilities": {"tools": {}},
|
||||
}
|
||||
rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result})
|
||||
continue
|
||||
|
||||
if method == "tools/list":
|
||||
rpc.write_message(
|
||||
{"jsonrpc": "2.0", "id": msg_id, "result": {"tools": TOOLS}}
|
||||
)
|
||||
continue
|
||||
|
||||
if method == "tools/call":
|
||||
tool_name = str(params.get("name") or "")
|
||||
args = params.get("arguments") or {}
|
||||
|
||||
routing_action, shadow = layer0_entry(
|
||||
_shadow_query_repr(tool_name, args)
|
||||
)
|
||||
if routing_action != "HANDOFF_TO_LAYER1":
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
_layer0_payload(routing_action, shadow), is_error=True
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
handler = handlers.get(tool_name)
|
||||
if not handler:
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
{
|
||||
"ok": False,
|
||||
"summary": f"Unknown tool: {tool_name}",
|
||||
"data": {"known_tools": sorted(handlers.keys())},
|
||||
"truncated": False,
|
||||
"next_steps": ["Call tools/list"],
|
||||
},
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
payload = handler(args)
|
||||
is_error = (
|
||||
not bool(payload.get("ok", True))
|
||||
if isinstance(payload, dict)
|
||||
else False
|
||||
)
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(payload, is_error=is_error),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Ignore notifications.
|
||||
if msg_id is None:
|
||||
continue
|
||||
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
{"ok": False, "summary": f"Unsupported method: {method}"},
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
if msg_id is not None:
|
||||
rpc.write_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": msg_id,
|
||||
"result": _mcp_text_result(
|
||||
{"ok": False, "summary": f"fatal error: {e}"},
|
||||
is_error=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _shadow_query_repr(tool_name: str, tool_args: Dict[str, Any]) -> str:
|
||||
if tool_name == "waf_capabilities":
|
||||
return "List WAF Intelligence capabilities."
|
||||
try:
|
||||
return f"{tool_name}: {json.dumps(tool_args, sort_keys=True, default=str)}"
|
||||
except Exception:
|
||||
return f"{tool_name}: {str(tool_args)}"
|
||||
|
||||
|
||||
def _layer0_payload(routing_action: str, shadow: ShadowEvalResult) -> Dict[str, Any]:
|
||||
if routing_action == "FAIL_CLOSED":
|
||||
return {"ok": False, "summary": "Layer 0: cannot comply with this request."}
|
||||
if routing_action == "HANDOFF_TO_GUARDRAILS":
|
||||
reason = shadow.reason or "governance_violation"
|
||||
return {
|
||||
"ok": False,
|
||||
"summary": f"Layer 0: governance violation detected ({reason}).",
|
||||
}
|
||||
if routing_action == "PROMPT_FOR_CLARIFICATION":
|
||||
return {
|
||||
"ok": False,
|
||||
"summary": "Layer 0: request is ambiguous. Please clarify and retry.",
|
||||
}
|
||||
return {"ok": False, "summary": "Layer 0: unrecognized routing action; refusing."}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -6,27 +6,26 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from mcp.waf_intelligence.analyzer import AnalysisResult, RuleViolation, WAFRuleAnalyzer
|
||||
from mcp.waf_intelligence.compliance import ComplianceMapper, FrameworkMapping
|
||||
from mcp.waf_intelligence.generator import GeneratedRule, WAFRuleGenerator
|
||||
from .analyzer import AnalysisResult, RuleViolation, WAFRuleAnalyzer
|
||||
from .compliance import ComplianceMapper, FrameworkMapping
|
||||
from .generator import GeneratedRule, WAFRuleGenerator
|
||||
|
||||
# Optional advanced modules (Phase 7)
|
||||
try:
|
||||
from mcp.waf_intelligence.threat_intel import (
|
||||
from .threat_intel import (
|
||||
ThreatIntelCollector,
|
||||
ThreatIntelReport,
|
||||
ThreatIndicator,
|
||||
)
|
||||
|
||||
_HAS_THREAT_INTEL = True
|
||||
except ImportError:
|
||||
_HAS_THREAT_INTEL = False
|
||||
ThreatIntelCollector = None
|
||||
|
||||
try:
|
||||
from mcp.waf_intelligence.classifier import (
|
||||
ThreatClassifier,
|
||||
ClassificationResult,
|
||||
)
|
||||
from .classifier import ThreatClassifier
|
||||
|
||||
_HAS_CLASSIFIER = True
|
||||
except ImportError:
|
||||
_HAS_CLASSIFIER = False
|
||||
@@ -45,14 +44,14 @@ class WAFInsight:
|
||||
@dataclass
|
||||
class ThreatAssessment:
|
||||
"""Phase 7: Comprehensive threat assessment result."""
|
||||
|
||||
|
||||
analysis_result: Optional[AnalysisResult] = None
|
||||
threat_report: Optional[Any] = None # ThreatIntelReport when available
|
||||
classification_summary: Dict[str, int] = field(default_factory=dict)
|
||||
risk_score: float = 0.0
|
||||
recommended_actions: List[str] = field(default_factory=list)
|
||||
generated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
@property
|
||||
def risk_level(self) -> str:
|
||||
if self.risk_score >= 0.8:
|
||||
@@ -81,22 +80,22 @@ class WAFIntelligence:
|
||||
enable_ml_classifier: bool = True,
|
||||
) -> None:
|
||||
self.workspace = Path(workspace_path) if workspace_path else Path.cwd()
|
||||
|
||||
|
||||
# Core components
|
||||
self.analyzer = WAFRuleAnalyzer()
|
||||
self.generator = WAFRuleGenerator()
|
||||
self.mapper = ComplianceMapper()
|
||||
|
||||
|
||||
# Phase 7 components (optional)
|
||||
self.threat_intel: Optional[Any] = None
|
||||
self.classifier: Optional[Any] = None
|
||||
|
||||
|
||||
if enable_threat_intel and _HAS_THREAT_INTEL:
|
||||
try:
|
||||
self.threat_intel = ThreatIntelCollector()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if enable_ml_classifier and _HAS_CLASSIFIER:
|
||||
try:
|
||||
self.classifier = ThreatClassifier()
|
||||
@@ -149,24 +148,24 @@ class WAFIntelligence:
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
Collect threat intelligence from logs and external feeds.
|
||||
|
||||
|
||||
Args:
|
||||
log_paths: Paths to Cloudflare log files
|
||||
max_indicators: Maximum indicators to collect
|
||||
|
||||
|
||||
Returns:
|
||||
ThreatIntelReport or None if unavailable
|
||||
"""
|
||||
if not self.threat_intel:
|
||||
return None
|
||||
|
||||
|
||||
# Default log paths
|
||||
if log_paths is None:
|
||||
log_paths = [
|
||||
str(self.workspace / "logs"),
|
||||
"/var/log/cloudflare",
|
||||
]
|
||||
|
||||
|
||||
return self.threat_intel.collect(
|
||||
log_paths=log_paths,
|
||||
max_indicators=max_indicators,
|
||||
@@ -175,16 +174,16 @@ class WAFIntelligence:
|
||||
def classify_threat(self, payload: str) -> Optional[Any]:
|
||||
"""
|
||||
Classify a payload using ML classifier.
|
||||
|
||||
|
||||
Args:
|
||||
payload: Request payload to classify
|
||||
|
||||
|
||||
Returns:
|
||||
ClassificationResult or None
|
||||
"""
|
||||
if not self.classifier:
|
||||
return None
|
||||
|
||||
|
||||
return self.classifier.classify(payload)
|
||||
|
||||
def full_assessment(
|
||||
@@ -195,51 +194,52 @@ class WAFIntelligence:
|
||||
) -> ThreatAssessment:
|
||||
"""
|
||||
Phase 7: Perform comprehensive threat assessment.
|
||||
|
||||
|
||||
Combines:
|
||||
- WAF configuration analysis
|
||||
- Threat intelligence collection
|
||||
- ML classification summary
|
||||
- Risk scoring
|
||||
|
||||
|
||||
Args:
|
||||
waf_config_path: Path to WAF Terraform file
|
||||
log_paths: Paths to log files
|
||||
include_threat_intel: Whether to collect threat intel
|
||||
|
||||
|
||||
Returns:
|
||||
ThreatAssessment with full analysis results
|
||||
"""
|
||||
assessment = ThreatAssessment()
|
||||
risk_factors: List[float] = []
|
||||
recommendations: List[str] = []
|
||||
|
||||
|
||||
# 1. Analyze WAF configuration
|
||||
if waf_config_path is None:
|
||||
waf_config_path = str(self.workspace / "terraform" / "waf.tf")
|
||||
|
||||
|
||||
if Path(waf_config_path).exists():
|
||||
assessment.analysis_result = self.analyzer.analyze_file(
|
||||
waf_config_path,
|
||||
min_severity="info",
|
||||
)
|
||||
|
||||
|
||||
# Calculate risk from violations
|
||||
severity_weights = {"error": 0.8, "warning": 0.5, "info": 0.2}
|
||||
for violation in assessment.analysis_result.violations:
|
||||
weight = severity_weights.get(violation.severity, 0.3)
|
||||
risk_factors.append(weight)
|
||||
|
||||
|
||||
# Generate recommendations
|
||||
critical_count = sum(
|
||||
1 for v in assessment.analysis_result.violations
|
||||
1
|
||||
for v in assessment.analysis_result.violations
|
||||
if v.severity == "error"
|
||||
)
|
||||
if critical_count > 0:
|
||||
recommendations.append(
|
||||
f"🔴 Fix {critical_count} critical WAF configuration issues"
|
||||
)
|
||||
|
||||
|
||||
# 2. Collect threat intelligence
|
||||
if include_threat_intel and self.threat_intel:
|
||||
try:
|
||||
@@ -247,52 +247,55 @@ class WAFIntelligence:
                    log_paths=log_paths,
                    max_indicators=50,
                )

                if assessment.threat_report:
                    indicators = assessment.threat_report.indicators

                    # Count by severity
                    severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
                    for ind in indicators:
                        sev = getattr(ind, "severity", "low")
                        severity_counts[sev] = severity_counts.get(sev, 0) + 1

                    # Add to classification summary
                    assessment.classification_summary["threat_indicators"] = len(indicators)
                    assessment.classification_summary["threat_indicators"] = len(
                        indicators
                    )
                    assessment.classification_summary.update(severity_counts)

                    # Calculate threat intel risk
                    if indicators:
                        critical_ratio = severity_counts["critical"] / len(indicators)
                        high_ratio = severity_counts["high"] / len(indicators)
                        risk_factors.append(critical_ratio * 0.9 + high_ratio * 0.7)

                        if severity_counts["critical"] > 0:
                            recommendations.append(
                                f"🚨 Block {severity_counts['critical']} critical threat IPs immediately"
                            )
            except Exception:
                pass

        # 3. ML classification summary (from any collected data)
        if self.classifier and assessment.threat_report:
            try:
                attack_types = {"sqli": 0, "xss": 0, "rce": 0, "clean": 0, "unknown": 0}

                indicators = assessment.threat_report.indicators
                pattern_indicators = [
                    i for i in indicators
                    i
                    for i in indicators
                    if getattr(i, "indicator_type", "") == "pattern"
                ]

                for ind in pattern_indicators[:20]:  # Sample first 20
                    result = self.classifier.classify(ind.value)
                    if result:
                        label = result.label
                        attack_types[label] = attack_types.get(label, 0) + 1

                assessment.classification_summary["ml_classifications"] = attack_types

                # Add ML risk factor
                dangerous = attack_types.get("sqli", 0) + attack_types.get("rce", 0)
                if dangerous > 5:
@@ -302,15 +305,17 @@ class WAFIntelligence:
                    )
            except Exception:
                pass

        # 4. Calculate final risk score
        if risk_factors:
            assessment.risk_score = min(1.0, sum(risk_factors) / max(len(risk_factors), 1))
            assessment.risk_score = min(
                1.0, sum(risk_factors) / max(len(risk_factors), 1)
            )
        else:
            assessment.risk_score = 0.3  # Baseline risk

        assessment.recommended_actions = recommendations

        return assessment
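
A minimal usage sketch (illustrative, not from the commit): running a full assessment with workspace defaults and reading the fields the MCP server exposes (risk_score, risk_level, recommended_actions, classification_summary); the printed output is illustrative.

    # Illustrative only: end-to-end assessment against the workspace defaults.
    waf = WAFIntelligence()
    assessment = waf.full_assessment(include_threat_intel=True)
    print(f"risk={assessment.risk_score:.2f} ({assessment.risk_level})")
    for action in assessment.recommended_actions[:5]:
        print("-", action)
    print(assessment.classification_summary)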

    def generate_gitops_proposals(
@@ -320,42 +325,44 @@ class WAFIntelligence:
    ) -> List[Dict[str, Any]]:
        """
        Generate GitOps-ready rule proposals.

        Args:
            threat_report: ThreatIntelReport to use
            max_proposals: Maximum proposals to generate

        Returns:
            List of proposal dicts ready for MR creation
        """
        proposals: List[Dict[str, Any]] = []

        if not threat_report:
            return proposals

        try:
            # Import proposer dynamically
            from gitops.waf_rule_proposer import WAFRuleProposer

            proposer = WAFRuleProposer(workspace_path=str(self.workspace))
            batch = proposer.generate_proposals(
                threat_report=threat_report,
                max_proposals=max_proposals,
            )

            for proposal in batch.proposals:
                proposals.append({
                    "name": proposal.rule_name,
                    "type": proposal.rule_type,
                    "severity": proposal.severity,
                    "confidence": proposal.confidence,
                    "terraform": proposal.terraform_code,
                    "justification": proposal.justification,
                    "auto_deploy": proposal.auto_deploy_eligible,
                })
                proposals.append(
                    {
                        "name": proposal.rule_name,
                        "type": proposal.rule_type,
                        "severity": proposal.severity,
                        "confidence": proposal.confidence,
                        "terraform": proposal.terraform_code,
                        "justification": proposal.justification,
                        "auto_deploy": proposal.auto_deploy_eligible,
                    }
                )
        except ImportError:
            pass

        return proposals
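
A minimal usage sketch (illustrative, not from the commit): feeding the threat report from a full assessment into the proposal generator and inspecting the dict keys shown above; the counts and printing are illustrative.

    # Illustrative only: produce MR-ready WAF rule proposals from collected intel.
    waf = WAFIntelligence()
    assessment = waf.full_assessment(include_threat_intel=True)
    proposals = waf.generate_gitops_proposals(
        threat_report=assessment.threat_report,
        max_proposals=5,
    )
    for p in proposals:
        print(p["name"], p["type"], p["severity"], "auto_deploy:", p["auto_deploy"])
        # p["terraform"] carries the generated rule body; p["justification"] explains it.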

    @property

326
mcp/waf_intelligence/server.py
Executable file → Normal file
@@ -1,326 +1,14 @@
#!/usr/bin/env python3
"""
WAF Intelligence MCP Server for VS Code Copilot.
from __future__ import annotations

This implements the Model Context Protocol (MCP) stdio interface
so VS Code can communicate with your WAF Intelligence system.
"""
Deprecated entrypoint kept for older editor configs.

Use `python3 -m mcp.waf_intelligence.mcp_server` (or `waf_intel_mcp.py`) instead.
"""

import json
import sys
from typing import Any

# Add parent to path for imports
sys.path.insert(0, '/Users/sovereign/Desktop/CLOUDFLARE')

from mcp.waf_intelligence.orchestrator import WAFIntelligence
from mcp.waf_intelligence.analyzer import WAFRuleAnalyzer
from layer0 import layer0_entry
from layer0.shadow_classifier import ShadowEvalResult


class WAFIntelligenceMCPServer:
    """MCP Server wrapper for WAF Intelligence."""

    def __init__(self):
        self.waf = WAFIntelligence()
        self.analyzer = WAFRuleAnalyzer()

    def get_capabilities(self) -> dict:
        """Return server capabilities."""
        return {
            "tools": [
                {
                    "name": "waf_analyze",
                    "description": "Analyze WAF logs and detect attack patterns",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "log_file": {
                                "type": "string",
                                "description": "Path to WAF log file (optional)"
                            },
                            "zone_id": {
                                "type": "string",
                                "description": "Cloudflare zone ID (optional)"
                            }
                        }
                    }
                },
                {
                    "name": "waf_assess",
                    "description": "Run full security assessment with threat intel and ML classification",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "zone_id": {
                                "type": "string",
                                "description": "Cloudflare zone ID"
                            }
                        },
                        "required": ["zone_id"]
                    }
                },
                {
                    "name": "waf_generate_rules",
                    "description": "Generate Terraform WAF rules from threat intelligence",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "zone_id": {
                                "type": "string",
                                "description": "Cloudflare zone ID"
                            },
                            "min_confidence": {
                                "type": "number",
                                "description": "Minimum confidence threshold (0-1)",
                                "default": 0.7
                            }
                        },
                        "required": ["zone_id"]
                    }
                },
                {
                    "name": "waf_capabilities",
                    "description": "List available WAF Intelligence capabilities",
                    "inputSchema": {
                        "type": "object",
                        "properties": {}
                    }
                }
            ]
        }
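
A client-side sketch (illustrative, not from the commit) of what a tools/call request for the waf_analyze schema above looks like when piped into this stdio server; the script path and log_file value are assumptions for illustration.

    # Illustrative only: drive the (now deprecated) stdio server with one JSON-RPC line.
    import json, subprocess

    request = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/call",
        "params": {
            "name": "waf_analyze",
            "arguments": {"log_file": "/var/log/cloudflare/waf.log"},  # illustrative path
        },
    }
    proc = subprocess.Popen(
        ["python3", "mcp/waf_intelligence/server.py"],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True,
    )
    out, _ = proc.communicate(json.dumps(request) + "\n")
    print(out)  # JSON-RPC response carrying a "content" list of text blocks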

    def handle_tool_call(self, name: str, arguments: dict) -> dict:
        """Handle a tool invocation."""
        try:
            if name == "waf_capabilities":
                return {
                    "content": [
                        {
                            "type": "text",
                            "text": json.dumps({
                                "capabilities": self.waf.capabilities,
                                "status": "operational"
                            }, indent=2)
                        }
                    ]
                }

            elif name == "waf_analyze":
                log_file = arguments.get("log_file")
                zone_id = arguments.get("zone_id")

                if log_file:
                    result = self.analyzer.analyze_log_file(log_file)
                else:
                    result = {
                        "message": "No log file provided. Use zone_id for live analysis.",
                        "capabilities": self.waf.capabilities
                    }

                return {
                    "content": [
                        {"type": "text", "text": json.dumps(result, indent=2, default=str)}
                    ]
                }

            elif name == "waf_assess":
                zone_id = arguments.get("zone_id")
                # full_assessment uses workspace paths, not zone_id
                assessment = self.waf.full_assessment(
                    include_threat_intel=True
                )
                # Build result from ThreatAssessment dataclass
                result = {
                    "zone_id": zone_id,
                    "risk_score": assessment.risk_score,
                    "risk_level": assessment.risk_level,
                    "classification_summary": assessment.classification_summary,
                    "recommended_actions": assessment.recommended_actions[:10],  # Top 10
                    "has_analysis": assessment.analysis_result is not None,
                    "has_threat_intel": assessment.threat_report is not None,
                    "generated_at": str(assessment.generated_at)
                }

                return {
                    "content": [
                        {"type": "text", "text": json.dumps(result, indent=2, default=str)}
                    ]
                }

            elif name == "waf_generate_rules":
                zone_id = arguments.get("zone_id")
                min_confidence = arguments.get("min_confidence", 0.7)

                # Generate proposals (doesn't use zone_id directly)
                proposals = self.waf.generate_gitops_proposals(
                    max_proposals=5
                )

                result = {
                    "zone_id": zone_id,
                    "min_confidence": min_confidence,
                    "proposals_count": len(proposals),
                    "proposals": proposals
                }

                return {
                    "content": [
                        {"type": "text", "text": json.dumps(result, indent=2, default=str) if proposals else "No rules generated (no threat data available)"}
                    ]
                }

            else:
                return {
                    "content": [
                        {"type": "text", "text": f"Unknown tool: {name}"}
                    ],
                    "isError": True
                }

        except Exception as e:
            return {
                "content": [
                    {"type": "text", "text": f"Error: {str(e)}"}
                ],
                "isError": True
            }

    def run(self):
        """Run the MCP server (stdio mode)."""
        # Send server info
        server_info = {
            "jsonrpc": "2.0",
            "method": "initialized",
            "params": {
                "serverInfo": {
                    "name": "waf-intelligence",
                    "version": "1.0.0"
                },
                "capabilities": self.get_capabilities()
            }
        }

        # Main loop - read JSON-RPC messages from stdin
        for line in sys.stdin:
            try:
                message = json.loads(line.strip())

                if message.get("method") == "initialize":
                    response = {
                        "jsonrpc": "2.0",
                        "id": message.get("id"),
                        "result": {
                            "protocolVersion": "2024-11-05",
                            "serverInfo": {
                                "name": "waf-intelligence",
                                "version": "1.0.0"
                            },
                            "capabilities": {
                                "tools": {}
                            }
                        }
                    }
                    print(json.dumps(response), flush=True)

                elif message.get("method") == "tools/list":
                    response = {
                        "jsonrpc": "2.0",
                        "id": message.get("id"),
                        "result": self.get_capabilities()
                    }
                    print(json.dumps(response), flush=True)

                elif message.get("method") == "tools/call":
                    params = message.get("params", {})
                    tool_name = params.get("name")
                    tool_args = params.get("arguments", {})

                    # Layer 0: pre-boot Shadow Eval gate before handling tool calls.
                    routing_action, shadow = layer0_entry(_shadow_query_repr(tool_name, tool_args))
                    if routing_action != "HANDOFF_TO_LAYER1":
                        response = _layer0_mcp_response(routing_action, shadow, message.get("id"))
                        print(json.dumps(response), flush=True)
                        continue

                    result = self.handle_tool_call(tool_name, tool_args)

                    response = {
                        "jsonrpc": "2.0",
                        "id": message.get("id"),
                        "result": result
                    }
                    print(json.dumps(response), flush=True)

                elif message.get("method") == "notifications/initialized":
                    # Client acknowledged initialization
                    pass

                else:
                    # Unknown method
                    response = {
                        "jsonrpc": "2.0",
                        "id": message.get("id"),
                        "error": {
                            "code": -32601,
                            "message": f"Method not found: {message.get('method')}"
                        }
                    }
                    print(json.dumps(response), flush=True)

            except json.JSONDecodeError:
                continue
            except Exception as e:
                error_response = {
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32603,
                        "message": str(e)
                    }
                }
                print(json.dumps(error_response), flush=True)

from .mcp_server import main

if __name__ == "__main__":
    server = WAFIntelligenceMCPServer()
    server.run()
    main()


def _shadow_query_repr(tool_name: str, tool_args: dict) -> str:
    """Build a textual representation of the tool call for Layer 0 classification."""
    try:
        return f"{tool_name}: {json.dumps(tool_args, sort_keys=True)}"
    except TypeError:
        return f"{tool_name}: {str(tool_args)}"


def _layer0_mcp_response(routing_action: str, shadow: ShadowEvalResult, msg_id: Any) -> dict:
    """
    Map Layer 0 outcomes to MCP responses.
    Catastrophic/forbidden/ambiguous short-circuit with minimal disclosure.
    """
    base = {"jsonrpc": "2.0", "id": msg_id}

    if routing_action == "FAIL_CLOSED":
        base["error"] = {"code": -32000, "message": "Layer 0: cannot comply with this request."}
        return base

    if routing_action == "HANDOFF_TO_GUARDRAILS":
        reason = shadow.reason or "governance_violation"
        base["error"] = {
            "code": -32001,
            "message": f"Layer 0: governance violation detected ({reason}).",
        }
        return base

    if routing_action == "PROMPT_FOR_CLARIFICATION":
        base["error"] = {
            "code": -32002,
            "message": "Layer 0: request is ambiguous. Please clarify and retry.",
        }
        return base

    base["error"] = {"code": -32099, "message": "Layer 0: unrecognized routing action; refusing."}
    return base
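
A minimal sketch (illustrative, not from the commit) of the short-circuit a client sees when Layer 0 fails closed; passing None for the shadow result is safe here because only the guardrails branch reads it, and the msg_id is illustrative.

    # Illustrative only: the JSON-RPC error returned for a FAIL_CLOSED routing action.
    blocked = _layer0_mcp_response("FAIL_CLOSED", shadow=None, msg_id=7)
    # {"jsonrpc": "2.0", "id": 7,
    #  "error": {"code": -32000, "message": "Layer 0: cannot comply with this request."}}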