from __future__ import annotations

import hashlib
import json
import os
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Sequence,
    Tuple,
)

CF_API_BASE = "https://api.cloudflare.com/client/v4"


def utc_now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def stable_hash(data: Any) -> str:
    """Return a deterministic SHA-256 hex digest of *data*.

    The value is canonicalized with ``json.dumps`` (sorted keys, compact
    separators) so that logically-equal structures always hash equally,
    regardless of original key order.
    """
    blob = json.dumps(
        data, sort_keys=True, separators=(",", ":"), ensure_ascii=False
    ).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()


class CloudflareError(RuntimeError):
    """Raised for Cloudflare API, configuration, or snapshot-store errors."""


@dataclass(frozen=True)
class CloudflareContext:
    """Credentials required to talk to the Cloudflare v4 API."""

    api_token: str
    account_id: str

    @staticmethod
    def from_env() -> "CloudflareContext":
        """Build a context from environment variables.

        Token is read from CLOUDFLARE_API_TOKEN / CF_API_TOKEN /
        CLOUDFLARE_TOKEN (first non-empty wins); account id from
        CLOUDFLARE_ACCOUNT_ID / CF_ACCOUNT_ID.

        Raises:
            CloudflareError: if either value is missing or blank.
        """
        api_token = (
            os.getenv("CLOUDFLARE_API_TOKEN")
            or os.getenv("CF_API_TOKEN")
            or os.getenv("CLOUDFLARE_TOKEN")
            or ""
        ).strip()
        account_id = (
            os.getenv("CLOUDFLARE_ACCOUNT_ID") or os.getenv("CF_ACCOUNT_ID") or ""
        ).strip()
        if not api_token:
            raise CloudflareError(
                "Missing Cloudflare API token. Set CLOUDFLARE_API_TOKEN (or CF_API_TOKEN)."
            )
        if not account_id:
            raise CloudflareError(
                "Missing Cloudflare account id. Set CLOUDFLARE_ACCOUNT_ID (or CF_ACCOUNT_ID)."
            )
        return CloudflareContext(api_token=api_token, account_id=account_id)


class CloudflareClient:
    """Minimal read-only Cloudflare v4 API client built on stdlib urllib."""

    def __init__(self, *, api_token: str) -> None:
        self.api_token = api_token

    def _request(
        self,
        method: str,
        path: str,
        *,
        params: Optional[Mapping[str, str]] = None,
        timeout: float = 30,
    ) -> Dict[str, Any]:
        """Perform one API call and return the decoded JSON envelope.

        Args:
            method: HTTP verb (e.g. "GET").
            path: Endpoint path appended to CF_API_BASE.
            params: Optional query-string parameters.
            timeout: Socket timeout in seconds (default 30, matching the
                previous hard-coded value).

        Raises:
            CloudflareError: on HTTP errors, transport failures, non-JSON
                response bodies, or an envelope with ``success: false``.
        """
        url = f"{CF_API_BASE}{path}"
        if params:
            url = f"{url}?{urllib.parse.urlencode(params)}"
        req = urllib.request.Request(
            url=url,
            method=method,
            headers={
                "Authorization": f"Bearer {self.api_token}",
                "Accept": "application/json",
                "Content-Type": "application/json",
            },
        )
        try:
            with urllib.request.urlopen(req, timeout=timeout) as resp:
                raw = resp.read()
        except urllib.error.HTTPError as e:
            # HTTPError bodies often carry the Cloudflare error envelope.
            raw = e.read() if hasattr(e, "read") else b""
            detail = raw.decode("utf-8", "replace")
            raise CloudflareError(
                f"Cloudflare API HTTP {e.code} for {path}: {detail}"
            ) from e
        except urllib.error.URLError as e:
            raise CloudflareError(
                f"Cloudflare API request failed for {path}: {e}"
            ) from e
        try:
            data = json.loads(raw.decode("utf-8", "replace"))
        except json.JSONDecodeError as e:
            # Fix: chain explicitly, consistent with the handlers above.
            raise CloudflareError(
                f"Cloudflare API returned non-JSON for {path}: {raw[:200]!r}"
            ) from e
        # A missing "success" key is treated as success: some endpoints
        # omit the flag on 2xx responses.
        if not data.get("success", True):
            raise CloudflareError(
                f"Cloudflare API error for {path}: {data.get('errors')}"
            )
        return data

    def paginate(
        self,
        path: str,
        *,
        params: Optional[Mapping[str, str]] = None,
        per_page: int = 100,
        max_pages: int = 5,
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Fetch a paginated Cloudflare endpoint.

        Stops after ``result_info.total_pages`` or ``max_pages`` pages,
        whichever comes first.

        Returns:
            Tuple of (accumulated results, last page's ``result_info``).
        """
        results: List[Dict[str, Any]] = []
        page = 1
        last_info: Dict[str, Any] = {}
        while True:
            merged_params: Dict[str, str] = {
                "page": str(page),
                "per_page": str(per_page),
            }
            if params:
                # Caller-supplied params win over the defaults above.
                merged_params.update({k: str(v) for k, v in params.items()})
            data = self._request("GET", path, params=merged_params)
            batch = data.get("result") or []
            if not isinstance(batch, list):
                # Some endpoints return a single object rather than a list.
                batch = [batch]
            results.extend(batch)
            last_info = data.get("result_info") or {}
            total_pages = int(last_info.get("total_pages") or 1)
            if page >= total_pages or page >= max_pages:
                break
            page += 1
        return results, last_info

    def list_zones(self) -> List[Dict[str, Any]]:
        """List zones visible to the token (capped at 2 pages)."""
        zones, _info = self.paginate("/zones", max_pages=2)
        return zones

    def list_dns_records_summary(
        self, zone_id: str, *, max_pages: int = 1
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """List DNS records for *zone_id* plus pagination info."""
        return self.paginate(f"/zones/{zone_id}/dns_records", max_pages=max_pages)

    def list_tunnels(self, account_id: str) -> List[Dict[str, Any]]:
        """List cloudflared tunnels for the account (capped at 2 pages)."""
        tunnels, _info = self.paginate(
            f"/accounts/{account_id}/cfd_tunnel", max_pages=2
        )
        return tunnels

    def list_tunnel_connections(
        self, account_id: str, tunnel_id: str
    ) -> List[Dict[str, Any]]:
        """List active connector connections for one tunnel."""
        data = self._request(
            "GET", f"/accounts/{account_id}/cfd_tunnel/{tunnel_id}/connections"
        )
        result = data.get("result") or []
        return result if isinstance(result, list) else [result]

    def list_access_apps(self, account_id: str) -> List[Dict[str, Any]]:
        """List Cloudflare Access applications (capped at 3 pages)."""
        apps, _info = self.paginate(f"/accounts/{account_id}/access/apps", max_pages=3)
        return apps

    def list_access_policies(
        self, account_id: str, app_id: str
    ) -> List[Dict[str, Any]]:
        """List policies attached to one Access application."""
        policies, _info = self.paginate(
            f"/accounts/{account_id}/access/apps/{app_id}/policies",
            max_pages=3,
        )
        return policies


@dataclass(frozen=True)
class SnapshotMeta:
    """Bookkeeping record for one snapshot stored on disk."""

    snapshot_id: str
    created_at: str
    scopes: List[str]
    snapshot_path: str


class SnapshotStore:
    """Creates, persists, and diffs point-in-time Cloudflare snapshots.

    Snapshots and diffs are written as JSON files under *root_dir*; the
    id -> metadata index is kept in memory only (not persisted).
    """

    def __init__(self, root_dir: Path) -> None:
        self.root_dir = root_dir
        self.snapshots_dir = root_dir / "snapshots"
        self.diffs_dir = root_dir / "diffs"
        self.snapshots_dir.mkdir(parents=True, exist_ok=True)
        self.diffs_dir.mkdir(parents=True, exist_ok=True)
        self._index: Dict[str, SnapshotMeta] = {}

    def get(self, snapshot_id: str) -> SnapshotMeta:
        """Return metadata for *snapshot_id* or raise CloudflareError."""
        if snapshot_id not in self._index:
            raise CloudflareError(f"Unknown snapshot_id: {snapshot_id}")
        return self._index[snapshot_id]

    def load_snapshot(self, snapshot_id: str) -> Dict[str, Any]:
        """Read a snapshot's JSON document back from disk."""
        meta = self.get(snapshot_id)
        return json.loads(Path(meta.snapshot_path).read_text(encoding="utf-8"))

    @staticmethod
    def _tunnels_scope(
        client: CloudflareClient, account_id: str
    ) -> List[Dict[str, Any]]:
        """Summarize tunnels, probing live connections per non-deleted tunnel."""
        tunnels_min: List[Dict[str, Any]] = []
        for t in client.list_tunnels(account_id):
            tid = t.get("id")
            status = t.get("status")
            connector_count: Optional[int] = None
            last_seen: Optional[str] = None
            if tid and status != "deleted":
                conns = client.list_tunnel_connections(account_id, str(tid))
                connector_count = len(conns)
                # Pick the most recent 'opened_at' if present.
                opened = [c.get("opened_at") for c in conns if isinstance(c, dict)]
                opened = [o for o in opened if isinstance(o, str)]
                last_seen = max(opened) if opened else None
            tunnels_min.append(
                {
                    "id": tid,
                    "name": t.get("name"),
                    "status": status,
                    "created_at": t.get("created_at"),
                    "deleted_at": t.get("deleted_at"),
                    "connector_count": connector_count,
                    "last_seen": last_seen,
                }
            )
        return tunnels_min

    @staticmethod
    def _access_apps_scope(
        client: CloudflareClient, account_id: str
    ) -> List[Dict[str, Any]]:
        """Summarize Access applications down to a few stable fields."""
        return [
            {
                "id": a.get("id"),
                "name": a.get("name"),
                "domain": a.get("domain"),
                "type": a.get("type"),
                "created_at": a.get("created_at"),
                "updated_at": a.get("updated_at"),
            }
            for a in client.list_access_apps(account_id)
        ]

    @staticmethod
    def _dns_scope(
        client: CloudflareClient,
        zone_id: Optional[str],
        zone_name: Optional[str],
        dns_max_pages: int,
    ) -> Dict[str, Any]:
        """Summarize DNS records for one zone, or a note if no zone resolved."""
        if not zone_id:
            return {
                "note": "dns scope requested but no zone_id/zone_name provided; only zones list included",
            }
        records, info = client.list_dns_records_summary(
            zone_id, max_pages=dns_max_pages
        )
        records_min = [
            {
                "id": r.get("id"),
                "type": r.get("type"),
                "name": r.get("name"),
                "content": r.get("content"),
                "proxied": r.get("proxied"),
                "ttl": r.get("ttl"),
            }
            for r in records
        ]
        return {
            "zone_id": zone_id,
            "zone_name": zone_name,
            "result_info": info,
            "records_sample": records_min,
        }

    def create_snapshot(
        self,
        *,
        client: CloudflareClient,
        ctx: CloudflareContext,
        scopes: Sequence[str],
        zone_id: Optional[str] = None,
        zone_name: Optional[str] = None,
        dns_max_pages: int = 1,
    ) -> Tuple[SnapshotMeta, Dict[str, Any]]:
        """Collect the requested scopes, persist a snapshot, and index it.

        Args:
            client: API client used for all reads.
            ctx: Account credentials/context.
            scopes: Any of "tunnels", "access_apps", "dns" (zones are
                always included).
            zone_id / zone_name: Zone selector for the "dns" scope; either
                may be given, the other is resolved from the zones list.
            dns_max_pages: Page cap for the DNS record sample.

        Returns:
            Tuple of (metadata, full snapshot dict).
        """
        scopes_norm = sorted(set(scopes))
        created_at = utc_now_iso()
        zones_min = [
            {
                "id": z.get("id"),
                "name": z.get("name"),
                "status": z.get("status"),
                "paused": z.get("paused"),
            }
            for z in client.list_zones()
        ]
        selected_zone_id = zone_id
        selected_zone_name = zone_name
        if not selected_zone_id and zone_name:
            for z in zones_min:
                if z.get("name") == zone_name:
                    selected_zone_id = str(z.get("id"))
                    break
        if selected_zone_id and not selected_zone_name:
            # Fix: backfill the zone name when only an id was supplied, so
            # the snapshot's dns.zone_name is not silently None.
            for z in zones_min:
                if str(z.get("id")) == str(selected_zone_id):
                    selected_zone_name = z.get("name")
                    break
        snapshot: Dict[str, Any] = {
            "meta": {
                # snapshot_id is filled in after hashing the content below.
                "snapshot_id": "",
                "created_at": created_at,
                "account_id": ctx.account_id,
                "scopes": scopes_norm,
            },
            "zones": zones_min,
        }
        if "tunnels" in scopes_norm:
            snapshot["tunnels"] = self._tunnels_scope(client, ctx.account_id)
        if "access_apps" in scopes_norm:
            snapshot["access_apps"] = self._access_apps_scope(client, ctx.account_id)
        if "dns" in scopes_norm:
            snapshot["dns"] = self._dns_scope(
                client, selected_zone_id, selected_zone_name, dns_max_pages
            )
        snapshot_id = (
            f"cf_{created_at.replace(':', '').replace('-', '').replace('.', '')}"
            f"_{stable_hash(snapshot)[:10]}"
        )
        snapshot["meta"]["snapshot_id"] = snapshot_id
        path = self.snapshots_dir / f"{snapshot_id}.json"
        path.write_text(
            json.dumps(snapshot, indent=2, ensure_ascii=False), encoding="utf-8"
        )
        meta = SnapshotMeta(
            snapshot_id=snapshot_id,
            created_at=created_at,
            scopes=scopes_norm,
            snapshot_path=str(path),
        )
        self._index[snapshot_id] = meta
        return meta, snapshot

    def diff(
        self,
        *,
        from_snapshot_id: str,
        to_snapshot_id: str,
        scopes: Optional[Sequence[str]] = None,
    ) -> Dict[str, Any]:
        """Compute added/removed/changed items between two snapshots.

        Only list-shaped scopes with "id"-keyed items ("tunnels",
        "access_apps", "zones") are compared. The diff is also written to
        the diffs directory; its path is returned under "diff_path".
        """
        before = self.load_snapshot(from_snapshot_id)
        after = self.load_snapshot(to_snapshot_id)
        scopes_before = set(before.get("meta", {}).get("scopes") or [])
        scopes_after = set(after.get("meta", {}).get("scopes") or [])
        scopes_all = sorted(scopes_before | scopes_after)
        scopes_use = sorted(set(scopes or scopes_all))

        def index_by_id(
            items: Iterable[Mapping[str, Any]],
        ) -> Dict[str, Dict[str, Any]]:
            out: Dict[str, Dict[str, Any]] = {}
            for it in items:
                _id = it.get("id")
                if _id is None:
                    continue
                out[str(_id)] = dict(it)
            return out

        diff_out: Dict[str, Any] = {
            "from": from_snapshot_id,
            "to": to_snapshot_id,
            "scopes": scopes_use,
            "changes": {},
        }
        for scope in scopes_use:
            if scope not in {"tunnels", "access_apps", "zones"}:
                continue
            b_items = before.get(scope) or []
            a_items = after.get(scope) or []
            if not isinstance(b_items, list) or not isinstance(a_items, list):
                continue
            b_map = index_by_id(b_items)
            a_map = index_by_id(a_items)
            added = [a_map[k] for k in sorted(set(a_map) - set(b_map))]
            removed = [b_map[k] for k in sorted(set(b_map) - set(a_map))]
            changed: List[Dict[str, Any]] = []
            for k in sorted(set(a_map) & set(b_map)):
                # Canonical-hash comparison tolerates key-order differences.
                if stable_hash(a_map[k]) != stable_hash(b_map[k]):
                    changed.append({"id": k, "before": b_map[k], "after": a_map[k]})
            diff_out["changes"][scope] = {
                "added": [{"id": x.get("id"), "name": x.get("name")} for x in added],
                "removed": [
                    {"id": x.get("id"), "name": x.get("name")} for x in removed
                ],
                "changed": [
                    {"id": x.get("id"), "name": x.get("after", {}).get("name")}
                    for x in changed
                ],
                "counts": {
                    "added": len(added),
                    "removed": len(removed),
                    "changed": len(changed),
                },
            }
        diff_path = self.diffs_dir / f"{from_snapshot_id}_to_{to_snapshot_id}.json"
        diff_path.write_text(
            json.dumps(diff_out, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        diff_out["diff_path"] = str(diff_path)
        return diff_out


def parse_cloudflared_config_ingress(config_text: str) -> List[Dict[str, str]]:
    """Best-effort parser for cloudflared YAML config ingress rules.

    We intentionally avoid a YAML dependency; this extracts common patterns:
      - hostname: example.com
        service: http://127.0.0.1:8080

    Returns a list of {"hostname": ..., "service": ...} dicts; "service"
    is "" when no service line is found under the hostname item.
    """
    rules: List[Dict[str, str]] = []
    lines = config_text.splitlines()
    i = 0
    while i < len(lines):
        line = lines[i]
        stripped = line.lstrip()
        if not stripped.startswith("-"):
            i += 1
            continue
        after_dash = stripped[1:].lstrip()
        if not after_dash.startswith("hostname:"):
            i += 1
            continue
        hostname = after_dash[len("hostname:") :].strip().strip('"').strip("'")
        base_indent = len(line) - len(line.lstrip())
        service = ""
        j = i + 1
        # Scan the item's continuation lines (deeper indent) for "service:".
        while j < len(lines):
            next_line = lines[j]
            if next_line.strip() == "":
                j += 1
                continue
            next_indent = len(next_line) - len(next_line.lstrip())
            if next_indent <= base_indent:
                # Dedent (or sibling list item) ends this ingress entry.
                break
            next_stripped = next_line.lstrip()
            if next_stripped.startswith("service:"):
                service = next_stripped[len("service:") :].strip().strip('"').strip("'")
                break
            j += 1
        rules.append({"hostname": hostname, "service": service})
        i = j
    return rules


def ingress_summary_from_file(
    *,
    config_path: str,
    max_rules: int = 50,
) -> Dict[str, Any]:
    """Summarize a cloudflared config file's ingress rules.

    Args:
        config_path: Path to the cloudflared YAML config.
        max_rules: Cap on hostnames/rules included in the summary.

    Raises:
        CloudflareError: if the file does not exist.
    """
    path = Path(config_path)
    if not path.exists():
        raise CloudflareError(f"cloudflared config not found: {config_path}")
    text = path.read_text(encoding="utf-8", errors="replace")
    rules = parse_cloudflared_config_ingress(text)
    hostnames = sorted({r["hostname"] for r in rules if r.get("hostname")})
    return {
        "config_path": config_path,
        "ingress_rule_count": len(rules),
        "hostnames": hostnames[:max_rules],
        "rules_sample": rules[:max_rules],
        "truncated": len(rules) > max_rules,
    }