Files
Vault Sovereign f0b8d962de
Some checks failed
WAF Intelligence Guardrail / waf-intel (push) Waiting to run
Cloudflare Registry Validation / validate-registry (push) Has been cancelled
chore: pre-migration snapshot
Layer0, MCP servers, Terraform consolidation
2025-12-27 01:52:27 +00:00

387 lines
13 KiB
Python

from __future__ import annotations
import asyncio
import json
import os
import sys
from typing import Any, Callable, Dict, List, Optional
from layer0 import layer0_entry
from layer0.shadow_classifier import ShadowEvalResult
from .tool import OracleAnswerTool
MAX_BYTES_DEFAULT = 32_000
def _max_bytes() -> int:
raw = (os.getenv("VM_MCP_MAX_BYTES") or "").strip()
if not raw:
return MAX_BYTES_DEFAULT
try:
return max(4_096, int(raw))
except ValueError:
return MAX_BYTES_DEFAULT
def _redact(obj: Any) -> Any:
sensitive_keys = ("token", "secret", "password", "private", "key", "certificate")
if isinstance(obj, dict):
out: Dict[str, Any] = {}
for k, v in obj.items():
if any(s in str(k).lower() for s in sensitive_keys):
out[k] = "<REDACTED>"
else:
out[k] = _redact(v)
return out
if isinstance(obj, list):
return [_redact(v) for v in obj]
if isinstance(obj, str):
if obj.startswith("ghp_") or obj.startswith("github_pat_"):
return "<REDACTED>"
return obj
return obj
def _safe_json(payload: Dict[str, Any]) -> str:
    """Serialise *payload* to JSON text that fits the byte budget.

    The payload is passed through ``_redact`` first. The result is then
    chosen so that its UTF-8 size never exceeds ``_max_bytes()`` where
    that is achievable:

    1. the indented (``indent=2``) encoding, if it fits;
    2. otherwise the compact encoding, if it fits;
    3. otherwise a small truncation notice carrying the payload's ``ok``,
       ``summary`` and ``next_steps`` (with defaults for missing keys).
       If even that notice is too large, the carried-over ``summary`` and
       ``next_steps`` are replaced by the defaults.

    Values that ``json`` cannot encode natively are stringified
    (``default=str``).
    """
    payload = _redact(payload)
    limit = _max_bytes()

    # Measure the exact text that will be returned; the indented form is
    # larger than the compact one, so checking only the compact size can
    # let an over-budget response through.
    pretty = json.dumps(payload, ensure_ascii=False, indent=2, default=str)
    if len(pretty.encode("utf-8")) <= limit:
        return pretty
    compact = json.dumps(
        payload, ensure_ascii=False, separators=(",", ":"), default=str
    )
    if len(compact.encode("utf-8")) <= limit:
        return compact

    default_summary = "Response exceeded max size; truncated."
    default_steps = [
        "request narrower outputs (e.g., fewer frameworks or shorter question)"
    ]
    # A non-dict payload has nothing to carry over; use the defaults.
    source = payload if isinstance(payload, dict) else {}
    notice = {
        "ok": source.get("ok", True),
        "truncated": True,
        "summary": source.get("summary", default_summary),
        "next_steps": source.get("next_steps", default_steps),
    }
    rendered = json.dumps(notice, ensure_ascii=False, indent=2, default=str)
    if len(rendered.encode("utf-8")) > limit:
        # The oversized part was the summary/next_steps themselves.
        notice["summary"] = default_summary
        notice["next_steps"] = default_steps
        rendered = json.dumps(notice, ensure_ascii=False, indent=2, default=str)
    return rendered
def _mcp_text_result(
    payload: Dict[str, Any], *, is_error: bool = False
) -> Dict[str, Any]:
    """Wrap *payload* as a tool result holding a single text item.

    The text is ``_safe_json(payload)``. ``"isError": True`` is added only
    when *is_error* is truthy; otherwise the key is absent.
    """
    text_item = {"type": "text", "text": _safe_json(payload)}
    if is_error:
        return {"content": [text_item], "isError": True}
    return {"content": [text_item]}
# Argument schema for the oracle_answer tool, one entry per argument.
_ORACLE_ANSWER_PROPERTIES: Dict[str, Any] = {
    "question": {
        "type": "string",
        "description": "The question to answer.",
    },
    "frameworks": {
        "type": "array",
        "items": {"type": "string"},
        "description": "Frameworks to reference (e.g., ['NIST-CSF','ISO-27001','GDPR']).",
    },
    "mode": {
        "type": "string",
        "enum": ["strict", "advisory"],
        "default": "strict",
        "description": "strict=conservative, advisory=exploratory.",
    },
    "local_only": {
        "type": "boolean",
        "description": "If true, skip NVIDIA API calls (uses local-only mode). Defaults to true when NVIDIA_API_KEY is missing.",
    },
}

# Tool descriptors returned verbatim in response to "tools/list".
TOOLS: List[Dict[str, Any]] = [
    {
        "name": "oracle_answer",
        "description": "Answer a compliance/security question (optionally via NVIDIA LLM) and map to frameworks.",
        "inputSchema": {
            "type": "object",
            "properties": _ORACLE_ANSWER_PROPERTIES,
            "required": ["question"],
        },
    }
]
class OracleAnswerTools:
    """Tool implementations exposed by this server."""

    async def oracle_answer(
        self,
        *,
        question: str,
        frameworks: Optional[List[str]] = None,
        mode: str = "strict",
        local_only: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """Answer *question* and return a result envelope.

        The question is first passed to ``layer0_entry``. Unless the
        routing action is ``"HANDOFF_TO_LAYER1"``, the refusal payload from
        ``_layer0_payload`` is returned and nothing else runs.

        Args:
            question: The question text.
            frameworks: Frameworks to reference; when falsy, the tool's
                ``default_frameworks`` are reported in the result.
            mode: Forwarded unchanged to ``OracleAnswerTool.answer``.
            local_only: Explicit local-only switch. When ``None`` it is
                true exactly when ``NVIDIA_API_KEY`` is unset or blank.

        Returns:
            A dict with ``ok``, ``summary``, ``data``, ``truncated`` and
            ``next_steps``. ``ok`` is false when constructing
            ``OracleAnswerTool`` raised; that exception's text becomes the
            summary. Exceptions from ``answer`` itself are not caught here.
        """
        routing_action, shadow = layer0_entry(question)
        if routing_action != "HANDOFF_TO_LAYER1":
            return _layer0_payload(routing_action, shadow)

        if local_only is None:
            # No explicit choice: go local-only when there is no API key.
            local_only_use = not (os.getenv("NVIDIA_API_KEY") or "").strip()
        else:
            local_only_use = bool(local_only)

        try:
            tool = OracleAnswerTool(
                default_frameworks=frameworks,
                use_local_only=local_only_use,
            )
        except Exception as e:  # noqa: BLE001
            has_key = bool((os.getenv("NVIDIA_API_KEY") or "").strip())
            failure: Dict[str, Any] = {
                "ok": False,
                "summary": str(e),
                "data": {
                    "local_only": local_only_use,
                    "has_nvidia_api_key": has_key,
                },
                "truncated": False,
                "next_steps": [
                    "Set NVIDIA_API_KEY to enable live answers",
                    "Or call oracle_answer(local_only=true, ...)",
                ],
            }
            return failure

        resp = await tool.answer(question=question, frameworks=frameworks, mode=mode)
        data: Dict[str, Any] = {
            "question": question,
            "mode": mode,
            "frameworks": frameworks or tool.default_frameworks,
            "local_only": local_only_use,
            "model": resp.model,
            "answer": resp.answer,
            "framework_hits": resp.framework_hits,
            "reasoning": resp.reasoning,
        }
        return {
            "ok": True,
            "summary": "Oracle answer generated.",
            "data": data,
            "truncated": False,
            "next_steps": [
                "If the answer is incomplete, add more specifics to the question or include more frameworks.",
            ],
        }
class StdioJsonRpc:
    """JSON message transport over the process's binary stdin/stdout.

    Two framings are understood; whichever one the peer uses first is
    remembered in ``_mode`` and used for replies:

    * ``"line"``    - one JSON object per line.
    * ``"headers"`` - a ``Content-Length: N`` header block, a blank line,
      then an N-byte JSON body.

    Until a framing has been detected, replies use the header framing.
    """

    def __init__(self) -> None:
        self._in = sys.stdin.buffer
        self._out = sys.stdout.buffer
        self._mode: Optional[str] = None  # "headers" | "line"

    @staticmethod
    def _decode_object(raw: bytes) -> Optional[Dict[str, Any]]:
        """Parse *raw* as UTF-8 JSON; return it only if it is a JSON object.

        Undecodable bytes are replaced rather than rejected. Malformed
        JSON and non-object values (arrays, scalars) both yield ``None``.
        """
        try:
            msg = json.loads(raw.decode("utf-8", "replace"))
        except (ValueError, RecursionError):
            return None
        return msg if isinstance(msg, dict) else None

    @staticmethod
    def _store_header(headers: Dict[str, str], raw: bytes) -> bool:
        """Record a ``Name: value`` line in *headers* (name lower-cased).

        Returns ``False``, leaving *headers* untouched, when the line has
        no colon.
        """
        name, sep, value = raw.decode("utf-8", "replace").strip().partition(":")
        if not sep:
            return False
        headers[name.lower().strip()] = value.strip()
        return True

    def read_message(self) -> Optional[Dict[str, Any]]:
        """Block until the next JSON object arrives and return it.

        Returns ``None`` when the input stream ends, or when a header
        block has no usable ``Content-Length`` (missing, non-integer,
        negative, or followed by an empty body). Bodies that are not
        valid JSON objects are skipped in both framings so that one bad
        message cannot take the reader down.
        """
        while True:
            if self._mode == "line":
                line = self._in.readline()
                if not line:
                    return None
                msg = self._decode_object(line)
                if msg is not None:
                    return msg
                continue

            first = self._in.readline()
            if not first:
                return None
            if first in (b"\r\n", b"\n"):
                continue

            # Auto-detect newline-delimited JSON framing.
            if self._mode is None and first.lstrip().startswith(b"{"):
                msg = self._decode_object(first)
                if msg is not None:
                    self._mode = "line"
                    return msg

            headers: Dict[str, str] = {}
            if not self._store_header(headers, first):
                continue
            while True:
                line = self._in.readline()
                if not line:
                    return None
                if line in (b"\r\n", b"\n"):
                    break
                self._store_header(headers, line)

            try:
                length = int(headers["content-length"])
            except (KeyError, ValueError):
                return None
            if length < 0:
                # A negative size must not reach read(): -1 there means
                # "read until EOF", which would swallow the whole stream.
                return None
            body = self._in.read(length)
            if not body:
                return None
            self._mode = "headers"
            msg = self._decode_object(body)
            if msg is not None:
                return msg
            # Malformed or non-object body: drop it and wait for the next
            # message, matching what line framing does.

    def write_message(self, message: Dict[str, Any]) -> None:
        """Serialise *message* compactly and write it in the active framing.

        Line framing appends a newline; otherwise a ``Content-Length``
        header block precedes the body. The output is flushed either way.
        """
        body = json.dumps(
            message, ensure_ascii=False, separators=(",", ":"), default=str
        ).encode("utf-8")
        if self._mode == "line":
            self._out.write(body + b"\n")
        else:
            header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
            self._out.write(header + body)
        self._out.flush()
def main() -> None:
    """Serve requests read from stdin until ``read_message`` returns None.

    Handled methods:

    * ``initialize`` - replies with protocol version, server info and
      capabilities.
    * ``tools/list`` - replies with ``TOOLS``.
    * ``tools/call`` - runs the named handler with the supplied arguments
      and replies with its payload wrapped by ``_mcp_text_result``; an
      unknown tool name gets an error result listing the known tools.

    Any other method is ignored when the message has no ``id`` and is
    answered with an "Unsupported method" error result otherwise. An
    exception raised while handling a message is reported the same way
    (only when the message has an ``id``), and the loop keeps running.
    """
    tools = OracleAnswerTools()
    rpc = StdioJsonRpc()
    handlers: Dict[str, Callable[[Dict[str, Any]], Any]] = {
        "oracle_answer": lambda a: tools.oracle_answer(**a),
    }

    def reply(request_id: Any, result: Dict[str, Any]) -> None:
        """Send *result* as the response to *request_id*."""
        rpc.write_message({"jsonrpc": "2.0", "id": request_id, "result": result})

    while True:
        msg = rpc.read_message()
        if msg is None:
            break

        method = msg.get("method")
        msg_id = msg.get("id")
        params = msg.get("params") or {}

        try:
            if method == "initialize":
                reply(
                    msg_id,
                    {
                        "protocolVersion": "2024-11-05",
                        "serverInfo": {"name": "oracle_answer", "version": "0.1.0"},
                        "capabilities": {"tools": {}},
                    },
                )
            elif method == "tools/list":
                reply(msg_id, {"tools": TOOLS})
            elif method == "tools/call":
                tool_name = str(params.get("name") or "")
                args = params.get("arguments") or {}
                handler = handlers.get(tool_name)
                if handler:
                    # Each call gets its own event loop.
                    payload = asyncio.run(handler(args))  # type: ignore[arg-type]
                    is_error = isinstance(payload, dict) and not bool(
                        payload.get("ok", True)
                    )
                    reply(msg_id, _mcp_text_result(payload, is_error=is_error))
                else:
                    unknown = {
                        "ok": False,
                        "summary": f"Unknown tool: {tool_name}",
                        "data": {"known_tools": sorted(handlers.keys())},
                        "truncated": False,
                        "next_steps": ["Call tools/list"],
                    }
                    reply(msg_id, _mcp_text_result(unknown, is_error=True))
            elif msg_id is not None:
                # Messages without an id are notifications and get no reply.
                reply(
                    msg_id,
                    _mcp_text_result(
                        {"ok": False, "summary": f"Unsupported method: {method}"},
                        is_error=True,
                    ),
                )
        except Exception as e:  # noqa: BLE001
            if msg_id is not None:
                reply(
                    msg_id,
                    _mcp_text_result(
                        {"ok": False, "summary": f"fatal error: {e}"},
                        is_error=True,
                    ),
                )
def _layer0_payload(routing_action: str, shadow: ShadowEvalResult) -> Dict[str, Any]:
if routing_action == "FAIL_CLOSED":
return {"ok": False, "summary": "Layer 0: cannot comply with this request."}
if routing_action == "HANDOFF_TO_GUARDRAILS":
reason = shadow.reason or "governance_violation"
return {
"ok": False,
"summary": f"Layer 0: governance violation detected ({reason}).",
}
if routing_action == "PROMPT_FOR_CLARIFICATION":
return {
"ok": False,
"summary": "Layer 0: request is ambiguous. Please clarify and retry.",
}
return {"ok": False, "summary": "Layer 0: unrecognized routing action; refusing."}
# Start the stdio server loop only when this file is run as a script.
if __name__ == "__main__":
    main()