Files
vm-cloudflare/mcp/oracle_answer/tool.py
Vault Sovereign f0b8d962de
Some checks failed
WAF Intelligence Guardrail / waf-intel (push) Waiting to run
Cloudflare Registry Validation / validate-registry (push) Has been cancelled
chore: pre-migration snapshot
Layer0, MCP servers, Terraform consolidation
2025-12-27 01:52:27 +00:00

215 lines
7.0 KiB
Python

"""
Core oracle tool implementation with NVIDIA AI integration.
This module contains the logic that answers compliance questions using
NVIDIA's API (free tier from build.nvidia.com).
Separate from CLI/API wrapper for clean testability.
"""
from __future__ import annotations
import asyncio
import json
import os
import urllib.error
import urllib.request
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
try:
import httpx
except ImportError:
httpx = None
@dataclass
class ToolResponse:
    """Canonical response from the oracle tool.

    Plain structured payload handed back to MCP / CLI callers.
    """

    # Natural-language answer text (LLM output, or an in-band error string).
    answer: str
    # Framework name -> sentences of `answer` that mention that framework.
    framework_hits: Dict[str, List[str]]
    # Human-readable note on how the answer was derived.
    reasoning: Optional[str] = None
    # Extra context gathered while answering; unused by the code visible
    # here — presumably consumed by callers, TODO confirm.
    raw_context: Optional[Dict[str, Any]] = None
    # Identifier of the backing model/provider.
    model: str = "nvidia"
class OracleAnswerTool:
    """
    Compliance / security oracle powered by NVIDIA AI.

    This tool:
    - takes `question`, `frameworks`, `mode`, etc.
    - queries NVIDIA's LLM API (free tier)
    - searches local documentation for context
    - assembles structured ToolResponse with framework mapping
    """

    # NVIDIA API configuration (build.nvidia.com free tier).
    NVIDIA_API_BASE = "https://integrate.api.nvidia.com/v1"
    NVIDIA_MODEL = "meta/llama-2-7b-chat"  # Free tier model

    def __init__(
        self,
        *,
        default_frameworks: Optional[List[str]] = None,
        api_key: Optional[str] = None,
        use_local_only: bool = False,
    ) -> None:
        """
        Initialize oracle with NVIDIA API integration.

        Args:
            default_frameworks: Default compliance frameworks to use
                (falls back to NIST-CSF and ISO-27001).
            api_key: NVIDIA API key (defaults to NVIDIA_API_KEY env var).
            use_local_only: If True, skip LLM calls (for testing).

        Raises:
            ValueError: If no API key is available and local-only mode is off.
        """
        self.default_frameworks = default_frameworks or ["NIST-CSF", "ISO-27001"]
        self.api_key = api_key or os.environ.get("NVIDIA_API_KEY")
        self.use_local_only = use_local_only
        if not self.use_local_only and not self.api_key:
            raise ValueError(
                "NVIDIA_API_KEY not found. Set it in .env or pass api_key parameter."
            )

    def _extract_framework_hits(
        self, answer: str, frameworks: List[str]
    ) -> Dict[str, List[str]]:
        """Extract mentions of frameworks from the LLM answer.

        Returns a mapping of framework name -> stripped sentences of
        `answer` that mention it (case-insensitive substring match).
        Frameworks with no mentions map to an empty list.
        """
        # Split into sentences once, outside the loop (the original
        # re-split the full answer for every matching framework).
        sentences = answer.split(".")
        hits: Dict[str, List[str]] = {}
        for framework in frameworks:
            fw_lower = framework.lower()
            # Simple keyword matching for framework mentions.
            hits[framework] = [
                sentence.strip()
                for sentence in sentences
                if fw_lower in sentence.lower()
            ]
        return hits

    async def _call_nvidia_api(self, prompt: str) -> str:
        """Call NVIDIA's API to get LLM response.

        Prefers httpx when installed; otherwise uses stdlib urllib on a
        worker thread. Failures are reported in-band as an
        "(API Error: ...)" string rather than raised, so callers always
        receive text.
        """
        if self.use_local_only:
            return "Local-only mode: skipping NVIDIA API call"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.NVIDIA_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.7,
            "top_p": 0.9,
            "max_tokens": 1024,
        }
        # Prefer httpx when available; otherwise fall back to stdlib urllib to avoid extra deps.
        if httpx:
            try:
                async with httpx.AsyncClient() as client:
                    response = await client.post(
                        f"{self.NVIDIA_API_BASE}/chat/completions",
                        json=payload,
                        headers=headers,
                        timeout=30.0,
                    )
                response.raise_for_status()
                data = response.json()
                return data["choices"][0]["message"]["content"]
            except Exception as e:  # noqa: BLE001
                # NOTE(review): message says "falling back" but no fallback
                # actually happens on the httpx path — the error string is
                # returned as the answer. Confirm whether urllib retry is
                # intended here.
                return f"(API Error: {str(e)}) Falling back to local analysis..."

        def _urllib_post() -> str:
            # Blocking stdlib POST; run via asyncio.to_thread below so the
            # event loop is not blocked.
            req = urllib.request.Request(
                url=f"{self.NVIDIA_API_BASE}/chat/completions",
                method="POST",
                headers=headers,
                data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
            )
            try:
                with urllib.request.urlopen(req, timeout=30) as resp:
                    raw = resp.read().decode("utf-8", "replace")
                data = json.loads(raw)
                return data["choices"][0]["message"]["content"]
            except urllib.error.HTTPError as e:
                # Surface the server's error body when possible; it usually
                # carries the actionable detail (quota, bad model, auth).
                detail = ""
                try:
                    detail = e.read().decode("utf-8", "replace")
                except Exception:
                    detail = str(e)
                raise RuntimeError(f"HTTP {e.code}: {detail}") from e

        try:
            return await asyncio.to_thread(_urllib_post)
        except Exception as e:  # noqa: BLE001
            return f"(API Error: {str(e)}) Falling back to local analysis..."

    async def answer(
        self,
        question: str,
        frameworks: Optional[List[str]] = None,
        mode: str = "strict",
    ) -> ToolResponse:
        """
        Main entry point for MCP / clients.

        Args:
            question: Compliance question to answer
            frameworks: Frameworks to reference (default: NIST-CSF, ISO-27001)
            mode: "strict" (conservative) or "advisory" (exploratory)

        Returns:
            ToolResponse with answer, framework hits, and reasoning
        """
        frameworks = frameworks or self.default_frameworks
        # Build context-aware prompt for NVIDIA API
        mode_instruction = (
            "conservative and cautious, assuming worst-case scenarios"
            if mode == "strict"
            else "exploratory and comprehensive"
        )
        prompt = f"""You are a compliance and security expert analyzing infrastructure questions.
Question: {question}
Compliance Frameworks to Consider:
{chr(10).join(f"- {fw}" for fw in frameworks)}
Analysis Mode: {mode_instruction}
Provide a structured answer that:
1. Directly addresses the question
2. References the relevant frameworks
3. Identifies gaps or risks
4. Suggests mitigations where applicable
Be concise but thorough."""
        # Call NVIDIA API for actual LLM response
        answer = await self._call_nvidia_api(prompt)
        # Extract framework mentions from the response
        framework_hits = self._extract_framework_hits(answer, frameworks)
        # Generate reasoning based on mode
        reasoning = (
            f"Analyzed question against frameworks: {', '.join(frameworks)}. "
            f"Mode={mode}. Used NVIDIA LLM for compliance analysis."
        )
        return ToolResponse(
            answer=answer,
            framework_hits=framework_hits,
            reasoning=reasoning,
            # Fix: report the configured model. The previous hard-coded
            # "nvidia/llama-2-7b-chat" contradicted NVIDIA_MODEL
            # ("meta/llama-2-7b-chat") used in the actual API request.
            model=self.NVIDIA_MODEL,
        )