387 lines
13 KiB
Python
387 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from layer0 import layer0_entry
|
|
from layer0.shadow_classifier import ShadowEvalResult
|
|
|
|
from .tool import OracleAnswerTool
|
|
|
|
# Fallback cap, in UTF-8 bytes, on the JSON text of a single tool response.
MAX_BYTES_DEFAULT = 32000


def _max_bytes() -> int:
    """Return the response size cap, honouring ``VM_MCP_MAX_BYTES`` when set.

    An unset, blank, or non-integer value yields ``MAX_BYTES_DEFAULT``.
    Integer values are clamped so the cap is never below 4096.
    """
    configured = (os.environ.get("VM_MCP_MAX_BYTES") or "").strip()
    try:
        # int("") raises ValueError too, so a blank setting takes this path.
        requested = int(configured)
    except ValueError:
        return MAX_BYTES_DEFAULT
    return max(4096, requested)
|
|
|
|
|
|
# Substrings that mark a mapping key as holding sensitive material (matched case-insensitively).
_SENSITIVE_KEY_FRAGMENTS = ("token", "secret", "password", "private", "key", "certificate")

# Prefixes of well-known credential strings: GitHub classic/fine-grained/OAuth/
# user-to-server/server-to-server/refresh tokens, and NVIDIA API keys.
_SECRET_VALUE_PREFIXES = ("ghp_", "github_pat_", "gho_", "ghu_", "ghs_", "ghr_", "nvapi-")


def _redact(obj: Any) -> Any:
    """Return a copy of *obj* with likely secrets replaced by ``"<REDACTED>"``.

    - dict: a value whose key contains a sensitive fragment is redacted, except
      bool/None values, which cannot carry a secret (e.g. ``has_nvidia_api_key``).
      Other values are redacted recursively.
    - list / tuple: each element is redacted recursively; the container kind is kept.
    - str: replaced wholesale when it starts with a known credential prefix.
    - anything else: returned unchanged.
    """
    if isinstance(obj, dict):
        out: Dict[str, Any] = {}
        for k, v in obj.items():
            key_is_sensitive = any(s in str(k).lower() for s in _SENSITIVE_KEY_FRAGMENTS)
            # Boolean/None flags under "key"-ish names are diagnostics, not secrets.
            if key_is_sensitive and not (v is None or isinstance(v, bool)):
                out[k] = "<REDACTED>"
            else:
                out[k] = _redact(v)
        return out

    if isinstance(obj, list):
        return [_redact(v) for v in obj]

    if isinstance(obj, tuple):
        # json.dumps serializes tuples like lists, so they must be scrubbed too.
        return tuple(_redact(v) for v in obj)

    if isinstance(obj, str) and obj.startswith(_SECRET_VALUE_PREFIXES):
        return "<REDACTED>"

    return obj
|
|
|
|
|
|
def _safe_json(payload: Dict[str, Any]) -> str:
    """Redact *payload* and render it as indented JSON within the size cap.

    If the rendered text exceeds ``_max_bytes()`` UTF-8 bytes, a small envelope
    is returned instead. It carries ``ok``, ``truncated: True`` and the payload's
    own ``summary``/``next_steps`` (or generic defaults). If even that envelope is
    over the cap, the defaults replace the payload's summary/next_steps.
    """
    limit = _max_bytes()
    redacted = _redact(payload)

    def render(obj: Any) -> str:
        return json.dumps(obj, ensure_ascii=False, indent=2, default=str)

    def fits(text: str) -> bool:
        return len(text.encode("utf-8")) <= limit

    # Measure exactly the text we emit (the indented form), not a compact variant.
    full_text = render(redacted)
    if fits(full_text):
        return full_text

    default_summary = "Response exceeded max size; truncated."
    default_next_steps = [
        "request narrower outputs (e.g., fewer frameworks or shorter question)"
    ]
    truncated = {
        "ok": redacted.get("ok", True),
        "truncated": True,
        "summary": redacted.get("summary", default_summary),
        "next_steps": redacted.get("next_steps", default_next_steps),
    }
    text = render(truncated)
    if fits(text):
        return text

    # The payload's own summary/next_steps are too large to carry; use the fixed defaults.
    truncated["summary"] = default_summary
    truncated["next_steps"] = default_next_steps
    return render(truncated)
|
|
|
|
|
|
def _mcp_text_result(payload: Dict[str, Any], *, is_error: bool = False) -> Dict[str, Any]:
    """Wrap *payload* as an MCP tool result holding one text content item.

    The text is produced by ``_safe_json``. When *is_error* is true, the result
    also carries ``"isError": True``.
    """
    text_item = {"type": "text", "text": _safe_json(payload)}
    envelope: Dict[str, Any] = {"content": [text_item]}
    if is_error:
        envelope["isError"] = True
    return envelope
|
|
|
|
|
|
# JSON Schema describing the arguments accepted by the ``oracle_answer`` tool.
_ORACLE_ANSWER_INPUT_SCHEMA: Dict[str, Any] = {
    "type": "object",
    "properties": {
        "question": {"type": "string", "description": "The question to answer."},
        "frameworks": {
            "type": "array",
            "items": {"type": "string"},
            "description": "Frameworks to reference (e.g., ['NIST-CSF','ISO-27001','GDPR']).",
        },
        "mode": {
            "type": "string",
            "enum": ["strict", "advisory"],
            "default": "strict",
            "description": "strict=conservative, advisory=exploratory.",
        },
        "local_only": {
            "type": "boolean",
            "description": "If true, skip NVIDIA API calls (uses local-only mode). Defaults to true when NVIDIA_API_KEY is missing.",
        },
    },
    "required": ["question"],
}

# Tool catalogue returned as-is in the ``tools/list`` response.
TOOLS: List[Dict[str, Any]] = [
    {
        "name": "oracle_answer",
        "description": "Answer a compliance/security question (optionally via NVIDIA LLM) and map to frameworks.",
        "inputSchema": _ORACLE_ANSWER_INPUT_SCHEMA,
    },
]
|
|
|
|
|
|
class OracleAnswerTools:
    """MCP-facing adapter around ``OracleAnswerTool``.

    Each question is routed through Layer 0 first. Only requests that Layer 0
    hands off to Layer 1 reach the oracle.
    """

    async def oracle_answer(
        self,
        *,
        question: str,
        frameworks: Optional[List[str]] = None,
        mode: str = "strict",
        local_only: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """Answer *question* and return an MCP payload dict.

        Args:
            question: Text passed to ``layer0_entry`` and, if handed off, to the oracle.
            frameworks: Frameworks to map against. When falsy, the tool's
                ``default_frameworks`` are reported in the result.
            mode: Forwarded unchanged to ``OracleAnswerTool.answer``.
            local_only: Explicit override. When None, local-only mode is used iff
                ``NVIDIA_API_KEY`` is unset or blank.

        Returns:
            A dict with ``ok``/``summary``/``data``/``truncated``/``next_steps``
            on success or when the tool cannot be constructed. Layer 0 refusals
            come from ``_layer0_payload`` instead.

        Exceptions raised by ``tool.answer`` are not caught here.
        """
        routing_action, shadow = layer0_entry(question)
        if routing_action != "HANDOFF_TO_LAYER1":
            return _layer0_payload(routing_action, shadow)

        has_api_key = bool((os.getenv("NVIDIA_API_KEY") or "").strip())
        use_local = (not has_api_key) if local_only is None else bool(local_only)

        try:
            tool = OracleAnswerTool(
                default_frameworks=frameworks,
                use_local_only=use_local,
            )
        except Exception as exc:  # noqa: BLE001
            # Construction failures are reported to the client rather than raised.
            return {
                "ok": False,
                "summary": str(exc),
                "data": {
                    "local_only": use_local,
                    "has_nvidia_api_key": has_api_key,
                },
                "truncated": False,
                "next_steps": [
                    "Set NVIDIA_API_KEY to enable live answers",
                    "Or call oracle_answer(local_only=true, ...)",
                ],
            }

        resp = await tool.answer(question=question, frameworks=frameworks, mode=mode)

        data = {
            "question": question,
            "mode": mode,
            "frameworks": frameworks or tool.default_frameworks,
            "local_only": use_local,
            "model": resp.model,
            "answer": resp.answer,
            "framework_hits": resp.framework_hits,
            "reasoning": resp.reasoning,
        }
        return {
            "ok": True,
            "summary": "Oracle answer generated.",
            "data": data,
            "truncated": False,
            "next_steps": [
                "If the answer is incomplete, add more specifics to the question or include more frameworks.",
            ],
        }
|
|
|
|
|
|
class StdioJsonRpc:
    """Minimal JSON-RPC transport over the process's stdin/stdout (binary buffers).

    The framing is detected from the first message received:

    * ``"line"``: one JSON object per line (newline-delimited JSON);
    * ``"headers"``: ``Content-Length`` header block, blank line, then the body.

    Replies use the detected framing. Before any message is read, header
    framing is used.
    """

    def __init__(self) -> None:
        self._in = sys.stdin.buffer
        self._out = sys.stdout.buffer
        # Detected framing: None until the first message, then "headers" or "line".
        self._mode: Optional[str] = None

    @staticmethod
    def _parse_object(text: str) -> Optional[Dict[str, Any]]:
        """Decode *text* as JSON; return it only if it is an object, else None."""
        try:
            msg = json.loads(text)
        except (ValueError, RecursionError):
            return None
        return msg if isinstance(msg, dict) else None

    @staticmethod
    def _store_header(line: bytes, headers: Dict[str, str]) -> bool:
        """Record a ``Name: value`` line in *headers*, with the name lower-cased.

        Returns False, leaving *headers* untouched, when the line has no colon.
        """
        name, sep, value = line.decode("utf-8", "replace").strip().partition(":")
        if not sep:
            return False
        headers[name.lower().strip()] = value.strip()
        return True

    def read_message(self) -> Optional[Dict[str, Any]]:
        """Return the next JSON-RPC object from stdin, or None to stop.

        None is returned at end of input, or when a header block has a missing,
        non-integer or negative ``Content-Length`` (the stream cannot be
        resynchronised). Blank lines, malformed JSON and non-object JSON values
        (e.g. batch arrays) are skipped, in both framings.
        """
        while True:
            if self._mode == "line":
                line = self._in.readline()
                if not line:
                    return None
                raw = line.decode("utf-8", "replace").strip()
                if not raw:
                    continue
                msg = self._parse_object(raw)
                if msg is not None:
                    return msg
                continue

            first = self._in.readline()
            if not first:
                return None
            if first in (b"\r\n", b"\n"):
                continue

            # Auto-detect newline-delimited JSON framing.
            if self._mode is None and first.lstrip().startswith(b"{"):
                msg = self._parse_object(first.decode("utf-8", "replace"))
                if msg is None:
                    # A line starting with "{" is never a header; do not parse it
                    # as one, or the following lines would be swallowed as headers.
                    continue
                self._mode = "line"
                return msg

            headers: Dict[str, str] = {}
            if not self._store_header(first, headers):
                continue
            while True:
                line = self._in.readline()
                if not line:
                    return None
                if line in (b"\r\n", b"\n"):
                    break
                self._store_header(line, headers)

            try:
                length = int(headers["content-length"])
            except (KeyError, ValueError):
                return None
            if length < 0:
                # read(-1) would consume the rest of the stream.
                return None

            body = self._in.read(length)
            if len(body) < length:
                return None  # EOF in the middle of a body.

            self._mode = "headers"
            msg = self._parse_object(body.decode("utf-8", "replace"))
            if msg is not None:
                return msg
            # The bad frame was consumed in full, so keep reading the next one.

    def write_message(self, message: Dict[str, Any]) -> None:
        """Serialize *message* compactly and write it using the detected framing."""
        body = json.dumps(
            message, ensure_ascii=False, separators=(",", ":"), default=str
        ).encode("utf-8")
        if self._mode == "line":
            self._out.write(body + b"\n")
        else:
            header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
            self._out.write(header + body)
        self._out.flush()
|
|
|
|
|
|
def main() -> None:
    """Serve the ``oracle_answer`` MCP tool over stdio until input ends.

    Handled methods are ``initialize``, ``ping``, ``tools/list`` and
    ``tools/call``. Any other method is ignored when it has no ``id``
    (notification). Otherwise it gets a JSON-RPC ``-32601`` "Method not found"
    error. Tool coroutines are run to completion with ``asyncio.run``.
    Exceptions raised while handling a message with an ``id`` are reported to
    the client as an MCP error result instead of stopping the loop.
    """
    tools = OracleAnswerTools()
    rpc = StdioJsonRpc()

    handlers: Dict[str, Callable[[Dict[str, Any]], Any]] = {
        "oracle_answer": lambda a: tools.oracle_answer(**a),
    }

    def reply(msg_id: Any, result: Dict[str, Any]) -> None:
        """Send a successful JSON-RPC response."""
        rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result})

    while True:
        msg = rpc.read_message()
        if msg is None:
            return

        method = msg.get("method")
        msg_id = msg.get("id")
        params = msg.get("params") or {}

        try:
            if method == "initialize":
                reply(
                    msg_id,
                    {
                        "protocolVersion": "2024-11-05",
                        "serverInfo": {"name": "oracle_answer", "version": "0.1.0"},
                        "capabilities": {"tools": {}},
                    },
                )
                continue

            if method == "ping":
                # MCP liveness check: the response must be an empty result object.
                if msg_id is not None:
                    reply(msg_id, {})
                continue

            if method == "tools/list":
                reply(msg_id, {"tools": TOOLS})
                continue

            if method == "tools/call":
                tool_name = str(params.get("name") or "")
                args = params.get("arguments") or {}
                handler = handlers.get(tool_name)
                if not handler:
                    reply(
                        msg_id,
                        _mcp_text_result(
                            {
                                "ok": False,
                                "summary": f"Unknown tool: {tool_name}",
                                "data": {"known_tools": sorted(handlers.keys())},
                                "truncated": False,
                                "next_steps": ["Call tools/list"],
                            },
                            is_error=True,
                        ),
                    )
                    continue

                if not isinstance(args, dict):
                    # Handlers unpack arguments as keywords; report a clear error
                    # instead of a TypeError surfacing as "fatal error".
                    reply(
                        msg_id,
                        _mcp_text_result(
                            {
                                "ok": False,
                                "summary": "Tool arguments must be a JSON object.",
                                "truncated": False,
                                "next_steps": ["Pass arguments as an object of named parameters"],
                            },
                            is_error=True,
                        ),
                    )
                    continue

                payload = asyncio.run(handler(args))  # type: ignore[arg-type]
                is_error = isinstance(payload, dict) and not bool(payload.get("ok", True))
                reply(msg_id, _mcp_text_result(payload, is_error=is_error))
                continue

            # Ignore notifications.
            if msg_id is None:
                continue

            # JSON-RPC 2.0: unknown methods get an error response, not a result.
            rpc.write_message(
                {
                    "jsonrpc": "2.0",
                    "id": msg_id,
                    "error": {"code": -32601, "message": f"Method not found: {method}"},
                }
            )
        except Exception as e:  # noqa: BLE001
            if msg_id is not None:
                reply(
                    msg_id,
                    _mcp_text_result(
                        {"ok": False, "summary": f"fatal error: {e}"},
                        is_error=True,
                    ),
                )
|
|
|
|
|
|
def _layer0_payload(routing_action: str, shadow: ShadowEvalResult) -> Dict[str, Any]:
    """Turn a non-handoff Layer 0 routing decision into a refusal payload.

    Every outcome has ``ok: False``. Only ``HANDOFF_TO_GUARDRAILS`` reads
    ``shadow.reason``, which falls back to ``"governance_violation"`` when falsy.
    Unknown routing actions are refused as well.
    """
    if routing_action == "HANDOFF_TO_GUARDRAILS":
        reason = shadow.reason or "governance_violation"
        summary = f"Layer 0: governance violation detected ({reason})."
    else:
        fixed_summaries = {
            "FAIL_CLOSED": "Layer 0: cannot comply with this request.",
            "PROMPT_FOR_CLARIFICATION": "Layer 0: request is ambiguous. Please clarify and retry.",
        }
        summary = fixed_summaries.get(
            routing_action, "Layer 0: unrecognized routing action; refusing."
        )
    return {"ok": False, "summary": summary}
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the stdio server only when run as a script, never on import.
    main()
|