from __future__ import annotations import asyncio import json import os import sys from typing import Any, Callable, Dict, List, Optional from layer0 import layer0_entry from layer0.shadow_classifier import ShadowEvalResult from .tool import OracleAnswerTool MAX_BYTES_DEFAULT = 32_000 def _max_bytes() -> int: raw = (os.getenv("VM_MCP_MAX_BYTES") or "").strip() if not raw: return MAX_BYTES_DEFAULT try: return max(4_096, int(raw)) except ValueError: return MAX_BYTES_DEFAULT def _redact(obj: Any) -> Any: sensitive_keys = ("token", "secret", "password", "private", "key", "certificate") if isinstance(obj, dict): out: Dict[str, Any] = {} for k, v in obj.items(): if any(s in str(k).lower() for s in sensitive_keys): out[k] = "" else: out[k] = _redact(v) return out if isinstance(obj, list): return [_redact(v) for v in obj] if isinstance(obj, str): if obj.startswith("ghp_") or obj.startswith("github_pat_"): return "" return obj return obj def _safe_json(payload: Dict[str, Any]) -> str: payload = _redact(payload) raw = json.dumps(payload, ensure_ascii=False, separators=(",", ":"), default=str) if len(raw.encode("utf-8")) <= _max_bytes(): return json.dumps(payload, ensure_ascii=False, indent=2, default=str) truncated = { "ok": payload.get("ok", True), "truncated": True, "summary": payload.get("summary", "Response exceeded max size; truncated."), "next_steps": payload.get( "next_steps", ["request narrower outputs (e.g., fewer frameworks or shorter question)"], ), } return json.dumps(truncated, ensure_ascii=False, indent=2, default=str) def _mcp_text_result( payload: Dict[str, Any], *, is_error: bool = False ) -> Dict[str, Any]: result: Dict[str, Any] = { "content": [{"type": "text", "text": _safe_json(payload)}] } if is_error: result["isError"] = True return result TOOLS: List[Dict[str, Any]] = [ { "name": "oracle_answer", "description": "Answer a compliance/security question (optionally via NVIDIA LLM) and map to frameworks.", "inputSchema": { "type": "object", "properties": { "question": { "type": "string", "description": "The question to answer.", }, "frameworks": { "type": "array", "items": {"type": "string"}, "description": "Frameworks to reference (e.g., ['NIST-CSF','ISO-27001','GDPR']).", }, "mode": { "type": "string", "enum": ["strict", "advisory"], "default": "strict", "description": "strict=conservative, advisory=exploratory.", }, "local_only": { "type": "boolean", "description": "If true, skip NVIDIA API calls (uses local-only mode). Defaults to true when NVIDIA_API_KEY is missing.", }, }, "required": ["question"], }, } ] class OracleAnswerTools: async def oracle_answer( self, *, question: str, frameworks: Optional[List[str]] = None, mode: str = "strict", local_only: Optional[bool] = None, ) -> Dict[str, Any]: routing_action, shadow = layer0_entry(question) if routing_action != "HANDOFF_TO_LAYER1": return _layer0_payload(routing_action, shadow) local_only_use = ( bool(local_only) if local_only is not None else not bool((os.getenv("NVIDIA_API_KEY") or "").strip()) ) try: tool = OracleAnswerTool( default_frameworks=frameworks, use_local_only=local_only_use, ) except Exception as e: # noqa: BLE001 return { "ok": False, "summary": str(e), "data": { "local_only": local_only_use, "has_nvidia_api_key": bool( (os.getenv("NVIDIA_API_KEY") or "").strip() ), }, "truncated": False, "next_steps": [ "Set NVIDIA_API_KEY to enable live answers", "Or call oracle_answer(local_only=true, ...)", ], } resp = await tool.answer(question=question, frameworks=frameworks, mode=mode) return { "ok": True, "summary": "Oracle answer generated.", "data": { "question": question, "mode": mode, "frameworks": frameworks or tool.default_frameworks, "local_only": local_only_use, "model": resp.model, "answer": resp.answer, "framework_hits": resp.framework_hits, "reasoning": resp.reasoning, }, "truncated": False, "next_steps": [ "If the answer is incomplete, add more specifics to the question or include more frameworks.", ], } class StdioJsonRpc: def __init__(self) -> None: self._in = sys.stdin.buffer self._out = sys.stdout.buffer self._mode: str | None = None # "headers" | "line" def read_message(self) -> Optional[Dict[str, Any]]: while True: if self._mode == "line": line = self._in.readline() if not line: return None raw = line.decode("utf-8", "replace").strip() if not raw: continue try: msg = json.loads(raw) except Exception: continue if isinstance(msg, dict): return msg continue first = self._in.readline() if not first: return None if first in (b"\r\n", b"\n"): continue # Auto-detect newline-delimited JSON framing. if self._mode is None and first.lstrip().startswith(b"{"): try: msg = json.loads(first.decode("utf-8", "replace")) except Exception: msg = None if isinstance(msg, dict): self._mode = "line" return msg headers: Dict[str, str] = {} try: text = first.decode("utf-8", "replace").strip() except Exception: continue if ":" not in text: continue k, v = text.split(":", 1) headers[k.lower().strip()] = v.strip() while True: line = self._in.readline() if not line: return None if line in (b"\r\n", b"\n"): break try: text = line.decode("utf-8", "replace").strip() except Exception: continue if ":" not in text: continue k, v = text.split(":", 1) headers[k.lower().strip()] = v.strip() if "content-length" not in headers: return None try: length = int(headers["content-length"]) except ValueError: return None body = self._in.read(length) if not body: return None self._mode = "headers" msg = json.loads(body.decode("utf-8", "replace")) if isinstance(msg, dict): return msg return None def write_message(self, message: Dict[str, Any]) -> None: if self._mode == "line": payload = json.dumps( message, ensure_ascii=False, separators=(",", ":"), default=str ).encode("utf-8") self._out.write(payload + b"\n") self._out.flush() return body = json.dumps( message, ensure_ascii=False, separators=(",", ":"), default=str ).encode("utf-8") header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8") self._out.write(header) self._out.write(body) self._out.flush() def main() -> None: tools = OracleAnswerTools() rpc = StdioJsonRpc() handlers: Dict[str, Callable[[Dict[str, Any]], Any]] = { "oracle_answer": lambda a: tools.oracle_answer(**a), } while True: msg = rpc.read_message() if msg is None: return method = msg.get("method") msg_id = msg.get("id") params = msg.get("params") or {} try: if method == "initialize": result = { "protocolVersion": "2024-11-05", "serverInfo": {"name": "oracle_answer", "version": "0.1.0"}, "capabilities": {"tools": {}}, } rpc.write_message({"jsonrpc": "2.0", "id": msg_id, "result": result}) continue if method == "tools/list": rpc.write_message( {"jsonrpc": "2.0", "id": msg_id, "result": {"tools": TOOLS}} ) continue if method == "tools/call": tool_name = str(params.get("name") or "") args = params.get("arguments") or {} handler = handlers.get(tool_name) if not handler: rpc.write_message( { "jsonrpc": "2.0", "id": msg_id, "result": _mcp_text_result( { "ok": False, "summary": f"Unknown tool: {tool_name}", "data": {"known_tools": sorted(handlers.keys())}, "truncated": False, "next_steps": ["Call tools/list"], }, is_error=True, ), } ) continue payload = asyncio.run(handler(args)) # type: ignore[arg-type] is_error = ( not bool(payload.get("ok", True)) if isinstance(payload, dict) else False ) rpc.write_message( { "jsonrpc": "2.0", "id": msg_id, "result": _mcp_text_result(payload, is_error=is_error), } ) continue # Ignore notifications. if msg_id is None: continue rpc.write_message( { "jsonrpc": "2.0", "id": msg_id, "result": _mcp_text_result( {"ok": False, "summary": f"Unsupported method: {method}"}, is_error=True, ), } ) except Exception as e: # noqa: BLE001 if msg_id is not None: rpc.write_message( { "jsonrpc": "2.0", "id": msg_id, "result": _mcp_text_result( {"ok": False, "summary": f"fatal error: {e}"}, is_error=True, ), } ) def _layer0_payload(routing_action: str, shadow: ShadowEvalResult) -> Dict[str, Any]: if routing_action == "FAIL_CLOSED": return {"ok": False, "summary": "Layer 0: cannot comply with this request."} if routing_action == "HANDOFF_TO_GUARDRAILS": reason = shadow.reason or "governance_violation" return { "ok": False, "summary": f"Layer 0: governance violation detected ({reason}).", } if routing_action == "PROMPT_FOR_CLARIFICATION": return { "ok": False, "summary": "Layer 0: request is ambiguous. Please clarify and retry.", } return {"ok": False, "summary": "Layer 0: unrecognized routing action; refusing."} if __name__ == "__main__": main()