| """SHIELD Diagnostic Script — full state analysis of all 4 layers.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import sys |
| from pathlib import Path |
|
|
| |
# On Windows, switch stdout to UTF-8 and substitute any character that still
# cannot be encoded instead of raising.
# NOTE(review): presumably needed because the report prints non-ASCII
# characters (e.g. em dashes) that legacy console code pages reject — confirm.
if sys.platform == "win32":
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")


# The directory two levels above this file is put first on the import path.
# This must run before the ``mcphunter`` imports that follow.
# NOTE(review): presumably this is the repository root — confirm layout.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
|
|
| from mcphunter.config import SEED_ATTACKS_PATH, SHIELD_CONFIG, ShieldConfig |
| from mcphunter.shield.layer_heuristic import HeuristicLayer, _WEIGHTS |
| from mcphunter.shield.layer_llm import LLMJudgeLayer |
| from mcphunter.shield.models import Attack, Verdict |
| from mcphunter.shield.pipeline import ShieldPipeline |
|
|
|
|
def load_seed_attacks() -> list[Attack]:
    """Load the seed attacks stored as JSON at ``SEED_ATTACKS_PATH``.

    Returns:
        One ``Attack`` per decoded entry, built with ``Attack.from_dict`` and
        kept in file order.
    """
    raw_text = SEED_ATTACKS_PATH.read_text(encoding="utf-8")
    entries = json.loads(raw_text)
    return list(map(Attack.from_dict, entries))
|
|
|
|
def section(title: str) -> None:
    """Print a section banner for *title*.

    The banner is a blank line, a 70-character ``=`` rule, the title preceded
    by one space, a second rule, and a trailing blank line.
    """
    rule = "=" * 70
    # Empty first/last fields produce the leading and trailing blank lines.
    print("", rule, f" {title}", rule, "", sep="\n")
|
|
|
|
def _preview(payload: str, limit: int) -> str:
    """Return the first ``limit`` characters of ``payload`` on a single line."""
    # Real newlines are shown as a visible backslash-n so each preview stays
    # on one output row.
    return payload[:limit].replace("\n", "\\n")


def _rate(hits: int, total: int) -> str:
    """Format ``hits/total (percent)``; print ``n/a`` when ``total`` is zero."""
    if total == 0:
        # An empty seed file must not abort the report with ZeroDivisionError.
        return f"{hits}/{total} (n/a)"
    return f"{hits}/{total} ({hits / total:.0%})"


def _report_evading(attacks: list[Attack]) -> tuple[ShieldPipeline, list[Attack]]:
    """Section 1: scan with the LLM layer disabled and list payloads judged SAFE.

    Returns:
        The pipeline that was used (its ``stats`` are read again in section 5)
        and the attacks whose scan verdict was ``Verdict.SAFE``.
    """
    section("1. EVADING ATTACKS (L1-L3 deterministic scan)")

    config_no_llm = ShieldConfig(llm_layer_enabled=False)
    pipeline_l123 = ShieldPipeline(config=config_no_llm)
    evading = [
        a for a in attacks if pipeline_l123.scan(a.payload).verdict == Verdict.SAFE
    ]

    detected_l123 = len(attacks) - len(evading)
    print(f"Detected by L1-L3: {_rate(detected_l123, len(attacks))}")
    print(f"Evading (SAFE): {len(evading)}/{len(attacks)}")
    print()

    for a in evading:
        print(f" {a.id:20s} | {a.attack_type.value:3s} | {_preview(a.payload, 80)}...")

    return pipeline_l123, evading


def _report_heuristic_scores(evading: list[Attack]) -> None:
    """Section 2: tabulate per-feature heuristic scores and their weighted total."""
    section("2. HEURISTIC SCORES FOR EVADING ATTACKS")

    heuristic = HeuristicLayer()
    print(f" {'ID':20s} | {'entropy':>7s} | {'hidden':>7s} | {'instr':>7s} | {'urls':>7s} | {'length':>7s} | {'auth':>7s} | {'TOTAL':>7s}")
    print(f" {'-'*20}-+-{'-'*7}-+-{'-'*7}-+-{'-'*7}-+-{'-'*7}-+-{'-'*7}-+-{'-'*7}-+-{'-'*7}")

    for a in evading:
        # Diagnostic-only use of the layer's private scoring hook and weights.
        scores = heuristic._compute_feature_scores(a.payload)
        ws = sum(scores[f] * _WEIGHTS[f] for f in _WEIGHTS)
        print(
            f" {a.id:20s} | {scores['entropy']:7.3f} | {scores['hidden_char_ratio']:7.3f} | "
            f"{scores['instruction_density']:7.3f} | {scores['url_count']:7.3f} | "
            f"{scores['length_anomaly']:7.3f} | {scores['authority_markers']:7.3f} | {ws:7.3f}"
        )


def _report_api_key() -> bool:
    """Section 3: report whether ``GEMINI_API_KEY`` is set; return True if so."""
    section("3. GEMINI API KEY STATUS")

    api_key = os.environ.get("GEMINI_API_KEY", "")
    if api_key:
        # Only the length is reported: echoing any part of the key would leak
        # secret material into terminals, CI logs and pasted bug reports.
        print(f" Key found (length={len(api_key)})")
        return True

    print(" Key missing — set GEMINI_API_KEY environment variable")
    return False


def _report_llm_judge(evading: list[Attack], has_key: bool) -> None:
    """Section 4: run the LLM judge layer on at most the first three evading attacks."""
    section("4. LAYER 4 (LLM JUDGE) ON FIRST 3 EVADING ATTACKS")

    if not has_key:
        print(" SKIPPED — no API key available")
        return
    if not evading:
        print(" SKIPPED — no evading attacks to test")
        return

    llm_layer = LLMJudgeLayer()
    for a in evading[:3]:
        print(f" Testing: {a.id} ({a.attack_type.value})")
        # The preview is built outside the f-string: a backslash inside an
        # f-string replacement field is a SyntaxError before Python 3.12.
        print(f" Payload: {_preview(a.payload, 100)}...")
        result = llm_layer.scan(a.payload)
        if result:
            print(f" Verdict: {result.verdict.value} (confidence={result.confidence:.2f})")
            print(f" Explanation: {result.explanation}")
        else:
            print(" Verdict: safe (LLM returned no detection)")
        print()


def _report_full_pipeline(
    attacks: list[Attack],
    pipeline_l123: ShieldPipeline,
    detected_l123: int,
    has_key: bool,
) -> None:
    """Section 5: run every attack through the ``SHIELD_CONFIG`` pipeline.

    Without an API key the scan is skipped and the section-1 figures
    (``detected_l123`` and ``pipeline_l123.stats``) are printed instead.
    """
    # The count comes from the loaded data rather than a hard-coded number.
    section(f"5. FULL 4-LAYER PIPELINE (ALL {len(attacks)} SEED ATTACKS)")

    if not has_key:
        stats = pipeline_l123.stats
        print(" SKIPPED — no API key; showing L1-L3 results only")
        print(f" Detection rate (L1-L3): {_rate(detected_l123, len(attacks))}")
        print(f" Layer breakdown: L1={stats.layer1_catches} L2={stats.layer2_catches} L3={stats.layer3_catches}")
        return

    pipeline_full = ShieldPipeline(config=SHIELD_CONFIG)
    # attack type -> [number scanned, number detected]
    by_type: dict[str, list[int]] = {}
    newly_caught: list[tuple[Attack, str]] = []

    for a in attacks:
        result = pipeline_full.scan(a.payload)
        counts = by_type.setdefault(a.attack_type.value, [0, 0])
        counts[0] += 1
        if result.verdict != Verdict.SAFE:
            counts[1] += 1
            if result.layer_triggered == "layer4_llm":
                newly_caught.append((a, result.explanation))

    total = sum(v[0] for v in by_type.values())
    detected = sum(v[1] for v in by_type.values())

    print(f" Overall detection rate: {_rate(detected, total)}")
    print()
    print(" Layer breakdown:")
    print(f" L1 (regex): {pipeline_full.stats.layer1_catches}")
    print(f" L2 (encoding): {pipeline_full.stats.layer2_catches}")
    print(f" L3 (heuristic): {pipeline_full.stats.layer3_catches}")
    print(f" L4 (LLM): {pipeline_full.stats.layer4_catches}")
    print(f" Safe (missed): {pipeline_full.stats.safe_count}")
    print()

    print(" Per-type breakdown:")
    for t in sorted(by_type):
        c, d = by_type[t]
        status = "PERFECT" if d == c else f"{c - d} missed"
        print(f" {t}: {d}/{c} ({status})")

    if newly_caught:
        print(f"\n Newly caught by Layer 4 ({len(newly_caught)} attacks):")
        for a, explanation in newly_caught:
            print(f" {a.id} ({a.attack_type.value}): {explanation[:100]}")


def _report_configuration() -> None:
    """Section 6: print the layer switches and LLM-judge settings from ``SHIELD_CONFIG``."""
    section("6. PIPELINE CONFIGURATION")

    config = SHIELD_CONFIG
    layer_flags = (
        ("Layer 1 (Regex)", config.regex_layer_enabled),
        ("Layer 2 (Encoding)", config.encoding_layer_enabled),
        ("Layer 3 (Heuristic)", config.heuristic_layer_enabled),
        ("Layer 4 (LLM)", config.llm_layer_enabled),
    )
    for label, enabled in layer_flags:
        print(f" {label}: {'ENABLED' if enabled else 'DISABLED'}")
    print()
    print(" Execution order: L1 -> L2 -> L3 -> L4 (short-circuit on first detection)")
    print(f" L4 model: {config.llm_judge_config.model_name}")
    print(f" L4 rate limit: {config.llm_judge_config.requests_per_day} RPD")
    print(f" L4 retry attempts: {config.llm_judge_config.retry_max_attempts}")
    print(f" L4 temperature: {config.llm_judge_config.temperature}")


def main() -> None:
    """Load the seed attacks and print the six-section SHIELD diagnostic report.

    Sections 4 and 5 use the LLM judge layer only when ``GEMINI_API_KEY`` is
    set in the environment; otherwise they print a SKIPPED notice.
    """
    attacks = load_seed_attacks()

    pipeline_l123, evading = _report_evading(attacks)
    _report_heuristic_scores(evading)
    has_key = _report_api_key()
    _report_llm_judge(evading, has_key)
    _report_full_pipeline(attacks, pipeline_l123, len(attacks) - len(evading), has_key)
    _report_configuration()

    # Reuse the shared banner helper instead of a hand-copied rule.
    section("DIAGNOSTIC COMPLETE")
|
|
|
|
# Run the diagnostic only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|
|