| """
|
| MITRE ATT&CK Cyber Knowledge Base Management Script
|
|
|
| This script manages the MITRE ATT&CK techniques knowledge base with:
|
| - Processing techniques.json file containing MITRE ATT&CK data
|
| - Semantic search using google/embeddinggemma-300m embeddings
|
| - Cross-encoder reranking using Qwen/Qwen3-Reranker-0.6B
|
| - Hybrid search combining ChromaDB (semantic) and BM25 (keyword)
|
| - Metadata filtering by tactics, platforms, and technique attributes
|
|
|
| Usage:
|
| python build_cyber_database.py ingest --techniques-json ./mitre_data/techniques.json
|
| python build_cyber_database.py test --query "process injection"
|
| python build_cyber_database.py test --interactive
|
| python build_cyber_database.py test --query "privilege escalation" --filter-tactics "privilege-escalation" --filter-platforms "Windows"
|
| """
|
|
|
| import argparse
|
| import os
|
| import sys
|
| from pathlib import Path
|
| from typing import Optional, List
|
|
|
|
|
# Put the directory three levels above this file at the front of sys.path so
# that the absolute `src.` package import that follows can be resolved when
# this file is executed directly as a script.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
|
|
|
| from langchain.text_splitter import TokenTextSplitter
|
| from src.knowledge_base.cyber_knowledge_base import CyberKnowledgeBase
|
|
|
|
|
def truncate_to_tokens(text: str, max_tokens: int = 300) -> str:
    """
    Return the leading portion of ``text`` that fits in ``max_tokens`` tokens.

    Newlines are turned into spaces first, then the text is chunked with
    LangChain's TokenTextSplitter (``cl100k_base`` encoding, no overlap) and
    only the first chunk is kept.

    Args:
        text: The text to shorten; any falsy value yields an empty string
        max_tokens: Chunk size, in tokens, handed to the splitter (default: 300)

    Returns:
        The first chunk produced by the splitter, or "" when there is none
    """
    if not text:
        return ""

    # Collapse the text onto a single line before tokenising.
    single_line = " ".join(text.split("\n"))

    token_splitter = TokenTextSplitter(
        encoding_name="cl100k_base",
        chunk_size=max_tokens,
        chunk_overlap=0,
    )
    pieces = token_splitter.split_text(single_line)

    # Only the first chunk matters; fall back to "" if nothing came back.
    return next(iter(pieces), "")
|
|
|
|
|
def validate_techniques_file(techniques_json_path: str) -> bool:
    """Validate that the techniques file exists and has the expected shape.

    The file must be readable UTF-8 JSON whose top level is a non-empty list,
    and whose first entry is a JSON object containing the ``attack_id``,
    ``name`` and ``description`` keys. Only the first entry is inspected.

    Args:
        techniques_json_path: Path to the techniques JSON file.

    Returns:
        True when every check passes, False otherwise. A ``[SUCCESS]`` or
        ``[ERROR]`` line describing the outcome is printed in both cases;
        no exception is propagated for unreadable or malformed files.
    """
    # json is not imported at module level in this file, so import it here.
    import json

    if not os.path.exists(techniques_json_path):
        print(f"[ERROR] Techniques file not found: {techniques_json_path}")
        return False

    # Keep the try body limited to the operations that can actually fail.
    try:
        with open(techniques_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"[ERROR] Invalid JSON format: {e}")
        return False
    except Exception as e:
        # Covers I/O problems (permissions, directories, bad encoding, ...).
        print(f"[ERROR] Error reading techniques file: {e}")
        return False

    if not isinstance(data, list):
        print("[ERROR] Invalid format: techniques.json should contain a list")
        return False

    if not data:
        print("[ERROR] Empty techniques file")
        return False

    first_technique = data[0]

    # Without this check a list of strings could pass the field test below,
    # because `"name" in "some name"` is a substring match, not a key lookup.
    if not isinstance(first_technique, dict):
        print("[ERROR] Invalid format: each technique should be a JSON object")
        return False

    required_fields = ["attack_id", "name", "description"]
    missing_fields = [
        field for field in required_fields if field not in first_technique
    ]
    if missing_fields:
        print(f"[ERROR] Missing required fields in techniques: {missing_fields}")
        return False

    print(f"[SUCCESS] Valid techniques file with {len(data)} techniques")
    return True
|
|
|
|
|
def ingest_techniques(args):
    """Build the knowledge base from the techniques file named in ``args``.

    The techniques file is validated first and the process exits with status 1
    if validation fails. Otherwise a CyberKnowledgeBase is built and persisted,
    and a summary of its statistics is printed (dict-valued statistics show at
    most their first five entries).

    Args:
        args: Parsed CLI namespace; ``techniques_json``, ``embedding_model``,
            ``persist_dir`` and ``reset`` are read from it.

    Returns:
        True when the build completed, False when an exception was raised
        during the build (the traceback is printed).
    """
    banner = "=" * 60
    print(banner)
    print("[INFO] INGESTING MITRE ATT&CK TECHNIQUES")
    print(banner)

    # Refuse to continue with a missing or malformed techniques file.
    techniques_ok = validate_techniques_file(args.techniques_json)
    if not techniques_ok:
        sys.exit(1)

    kb = CyberKnowledgeBase(embedding_model=args.embedding_model)

    try:
        kb.build_knowledge_base(
            techniques_json_path=args.techniques_json,
            persist_dir=args.persist_dir,
            reset=args.reset,
        )

        print("\n[INFO] Knowledge Base Statistics:")
        for stat_name, stat_value in kb.get_stats().items():
            if not isinstance(stat_value, dict):
                print(f" {stat_name}: {stat_value}")
                continue

            # Nested statistics: show a short preview rather than every entry.
            print(f" {stat_name}:")
            entries = list(stat_value.items())
            for sub_name, sub_value in entries[:5]:
                print(f" {sub_name}: {sub_value}")
            if len(stat_value) > 5:
                print(f" ... and {len(stat_value) - 5} more")

        print(f"\n[SUCCESS] Knowledge base saved successfully to {args.persist_dir}!")
        return True

    except Exception as exc:
        print(f"[ERROR] Error during ingestion: {exc}")
        import traceback

        traceback.print_exc()
        return False
|
|
|
|
|
def test_retrieval(args):
    """Load a persisted knowledge base and exercise retrieval against it.

    After loading, a statistics summary is printed (dict-valued statistics
    show at most their first five entries). Then exactly one mode runs:
    the interactive session if ``args.interactive`` is set, a single query if
    ``args.query`` is set, and the built-in test suite otherwise.

    Args:
        args: Parsed CLI namespace; ``embedding_model``, ``persist_dir``,
            ``interactive``, ``query``, ``filter_tactics`` and
            ``filter_platforms`` are read from it.

    The process exits with status 1 when the knowledge base cannot be loaded.
    """
    banner = "=" * 60
    print(banner)
    print("[INFO] TESTING CYBER KNOWLEDGE BASE")
    print(banner)

    kb = CyberKnowledgeBase(embedding_model=args.embedding_model)

    if not kb.load_knowledge_base(persist_dir=args.persist_dir):
        print("[ERROR] Failed to load knowledge base. Run 'ingest' first.")
        sys.exit(1)

    print("\n[INFO] Knowledge Base Statistics:")
    for stat_name, stat_value in kb.get_stats().items():
        if not isinstance(stat_value, dict):
            print(f" {stat_name}: {stat_value}")
            continue

        # Nested statistics: show a short preview rather than every entry.
        print(f" {stat_name}:")
        entries = list(stat_value.items())
        for sub_name, sub_value in entries[:5]:
            print(f" {sub_name}: {sub_value}")
        if len(stat_value) > 5:
            print(f" ... and {len(stat_value) - 5} more")

    # Pick the test mode; interactive takes precedence over a single query.
    if args.interactive:
        run_interactive_tests(kb)
        return

    if args.query:
        test_single_query(kb, args.query, args.filter_tactics, args.filter_platforms)
        return

    run_test_suite(kb)
|
|
|
|
|
def test_single_query(
    kb,
    query: str,
    filter_tactics: Optional[List[str]] = None,
    filter_platforms: Optional[List[str]] = None,
):
    """Run one search against ``kb`` and print the detailed results.

    Args:
        kb: Knowledge base object exposing ``search``.
        query: Search text.
        filter_tactics: Optional tactic names forwarded to ``kb.search``.
        filter_platforms: Optional platform names forwarded to ``kb.search``.

    Any exception raised while searching or displaying is reported on stdout
    together with its traceback instead of being propagated.
    """
    print(f"\n[INFO] Testing Query: '{query}'")
    if filter_tactics:
        print(f"[INFO] Filtering by tactics: {filter_tactics}")
    if filter_platforms:
        print(f"[INFO] Filtering by platforms: {filter_platforms}")
    print("-" * 40)

    search_options = {
        "top_k": 20,
        "filter_tactics": filter_tactics,
        "filter_platforms": filter_platforms,
    }

    try:
        matches = kb.search(query, **search_options)
        display_detailed_results(matches)
    except Exception as exc:
        print(f"[ERROR] Error during search: {exc}")
        import traceback

        traceback.print_exc()
|
|
|
|
|
def display_detailed_results(results):
    """Print search results with their MITRE ATT&CK metadata.

    For every result the attack id, name, technique type, tactics, platforms,
    mitigations (truncated), mitigation count and a truncated description are
    printed. The description is taken from the first ``page_content`` line
    that starts with ``Description:``; when there is no such line the whole
    ``page_content`` is previewed instead.

    Args:
        results: Iterable of documents exposing ``metadata`` (a mapping) and
            ``page_content`` (a string). A falsy value prints a
            "No results found" line.
    """
    if not results:
        print(" No results found")
        return

    description_label = "Description:"

    for i, doc in enumerate(results, 1):
        metadata = doc.metadata
        attack_id = metadata.get("attack_id", "Unknown")
        name = metadata.get("name", "Unknown")
        tactics_str = metadata.get("tactics", "")
        platforms_str = metadata.get("platforms", "")
        is_subtechnique = metadata.get("is_subtechnique", False)
        mitigation_count = metadata.get("mitigation_count", 0)
        mitigations = metadata.get("mitigations", "")

        description_line = next(
            (
                line
                for line in doc.page_content.split("\n")
                if line.startswith(description_label)
            ),
            "",
        )
        if description_line:
            # Strip only the leading label. A global str.replace would also
            # delete any later "Description: " inside the description text,
            # and would leave the label in place when no space follows it.
            description = description_line[len(description_label):].lstrip()
            content_preview = truncate_to_tokens(description, 300)
        else:
            content_preview = truncate_to_tokens(doc.page_content, 300)

        mitigation_preview = truncate_to_tokens(mitigations, 300)

        print(f" {i}. {attack_id} - {name}")
        print(f" Type: {'Sub-technique' if is_subtechnique else 'Technique'}")
        print(f" Tactics: {tactics_str if tactics_str else 'None'}")
        print(f" Platforms: {platforms_str if platforms_str else 'None'}")
        print(
            f" Mitigations: {mitigation_preview if mitigation_preview else 'None'}"
        )
        print(f" Mitigation Count: {mitigation_count}")
        print(f" Description: {content_preview}")
        print()
|
|
|
|
|
def _parse_filter_command(user_input: str):
    """Split a ``filter ...`` command into ``(query, tactics, platforms)``.

    Leading ``tactics:a,b`` / ``platforms:x,y`` tokens after the ``filter``
    keyword are consumed as filters; everything from the first other token
    onwards is the query. A filter with no usable values yields ``None``.
    """
    # split() without an argument tolerates repeated spaces between tokens;
    # split(" ") would produce empty tokens that end filter parsing early.
    tokens = user_input.split()
    filter_tactics = None
    filter_platforms = None
    query_start = 1

    for index, token in enumerate(tokens[1:], 1):
        if token.startswith("tactics:"):
            values = [v for v in token.split(":", 1)[1].split(",") if v]
            filter_tactics = values or None
        elif token.startswith("platforms:"):
            values = [v for v in token.split(":", 1)[1].split(",") if v]
            filter_platforms = values or None
        else:
            break
        query_start = index + 1

    return " ".join(tokens[query_start:]), filter_tactics, filter_platforms


def run_interactive_tests(kb):
    """Run an interactive search session against ``kb`` until the user quits.

    Supported commands are listed on start-up: plain queries, ``stats``,
    ``tactics``, ``platforms``, ``technique <id>``, ``filter ...`` and
    ``quit``/``exit``/``q``. The session also ends on Ctrl-C or when standard
    input reaches end-of-file. Any other exception raised while handling a
    command is printed and the loop continues.

    Args:
        kb: Knowledge base object; ``search`` is called for queries, and the
            object is passed on to the display helpers for the other commands.
    """
    print("\n[INFO] Interactive Testing Mode")
    print("Available commands:")
    print(" - Enter a query to search")
    print(" - 'stats' to view knowledge base statistics")
    print(" - 'tactics' to list available tactics")
    print(" - 'platforms' to list available platforms")
    print(
        " - 'filter tactics:defense-evasion,privilege-escalation query' to filter by tactics"
    )
    print(" - 'filter platforms:Windows,Linux query' to filter by platforms")
    print(" - 'technique T1055' to get specific technique info")
    print(" - 'quit' to exit")
    print("-" * 50)

    while True:
        try:
            user_input = input("\n[INPUT] Enter command: ").strip()
            command = user_input.lower()

            if command in ("quit", "exit", "q"):
                break

            if not user_input:
                continue

            if command == "stats":
                display_stats(kb)
                continue

            if command == "tactics":
                display_available_tactics(kb)
                continue

            if command == "platforms":
                display_available_platforms(kb)
                continue

            if command.startswith("technique "):
                technique_id = user_input.split(" ", 1)[1].strip()
                display_technique_info(kb, technique_id)
                continue

            filter_tactics = None
            filter_platforms = None
            query = user_input

            if command.startswith("filter "):
                query, filter_tactics, filter_platforms = _parse_filter_command(
                    user_input
                )
                if not query.strip():
                    print("[ERROR] No query provided")
                    continue

            print(f"\n[INFO] Search: '{query}'")
            if filter_tactics:
                print(f"[INFO] Filtering by tactics: {filter_tactics}")
            if filter_platforms:
                print(f"[INFO] Filtering by platforms: {filter_platforms}")

            results = kb.search(
                query,
                top_k=20,
                filter_tactics=filter_tactics,
                filter_platforms=filter_platforms,
            )
            display_detailed_results(results)

        except (KeyboardInterrupt, EOFError):
            # EOFError must end the session too: once stdin is exhausted every
            # further input() call raises it again, so treating it as an
            # ordinary error would loop forever.
            print("\n[INFO] Exiting interactive mode...")
            break
        except Exception as e:
            print(f"[ERROR] Error: {e}")
|
|
|
|
|
def display_stats(kb):
    """Print every statistic reported by ``kb.get_stats()``.

    Dict-valued statistics are expanded in full, one line per entry; all
    other values are printed on a single ``name: value`` line.
    """
    all_stats = kb.get_stats()
    print("\n[INFO] Knowledge Base Statistics:")

    for stat_name, stat_value in all_stats.items():
        if not isinstance(stat_value, dict):
            print(f" {stat_name}: {stat_value}")
            continue

        # Nested mapping: heading line followed by each of its entries.
        print(f" {stat_name}:")
        for entry_name, entry_value in stat_value.items():
            print(f" {entry_name}: {entry_value}")
|
|
|
|
|
def display_available_tactics(kb):
    """Print the tactics found under ``techniques_by_tactic`` in the stats.

    Tactics are listed in sorted order with their technique counts; a notice
    is printed instead when that statistic is missing or empty.
    """
    tactic_counts = kb.get_stats().get("techniques_by_tactic", {})

    if not tactic_counts:
        print("\n[INFO] No tactics information available")
        return

    print("\n[INFO] Available Tactics:")
    for tactic in sorted(tactic_counts):
        print(f" {tactic}: {tactic_counts[tactic]} techniques")
|
|
|
|
|
def display_available_platforms(kb):
    """Print the platforms found under ``techniques_by_platform`` in the stats.

    Platforms are listed in sorted order with their technique counts; a
    notice is printed instead when that statistic is missing or empty.
    """
    platform_counts = kb.get_stats().get("techniques_by_platform", {})

    if not platform_counts:
        print("\n[INFO] No platforms information available")
        return

    print("\n[INFO] Available Platforms:")
    for platform in sorted(platform_counts):
        print(f" {platform}: {platform_counts[platform]} techniques")
|
|
|
|
|
def display_technique_info(kb, technique_id: str):
    """Print the details of one technique looked up by its identifier.

    The identifier is upper-cased for the lookup via
    ``kb.get_technique_by_id``; the messages echo it as the caller typed it.
    The description is cut at 500 characters and the detection text at 300,
    each followed by ``...`` when something was cut off.

    Args:
        kb: Knowledge base object exposing ``get_technique_by_id``.
        technique_id: Technique identifier, in any letter case.
    """
    technique = kb.get_technique_by_id(technique_id.upper())

    if not technique:
        print(f"\n[ERROR] Technique {technique_id} not found")
        return

    kind_label = "Sub-technique" if technique.get("is_subtechnique") else "Technique"
    tactic_names = ", ".join(technique.get("tactics", []))
    platform_names = ", ".join(technique.get("platforms", []))
    mitigation_total = len(technique.get("mitigations", []))

    print(f"\n[INFO] Technique Details: {technique_id}")
    print("-" * 40)
    print(f"Name: {technique.get('name', 'Unknown')}")
    print(f"Type: {kind_label}")
    print(f"Tactics: {tactic_names}")
    print(f"Platforms: {platform_names}")
    print(f"Mitigations: {mitigation_total}")

    description = technique.get("description", "")
    if description:
        ellipsis = "..." if len(description) > 500 else ""
        print(f"Description: {description[:500]}{ellipsis}")

    detection = technique.get("detection", "")
    if detection:
        ellipsis = "..." if len(detection) > 300 else ""
        print(f"Detection: {detection[:300]}{ellipsis}")
|
|
|
|
|
def run_test_suite(kb):
    """Run the built-in set of sample cyber-security queries against ``kb``.

    Each query is searched with ``top_k=3`` and its results are printed.
    A failure in one query is reported on stdout and does not stop the
    remaining queries.
    """
    # (query, human-readable description) pairs, in execution order.
    sample_queries = (
        ("process injection", "Process injection techniques"),
        ("DLL injection", "DLL injection methods"),
        ("privilege escalation Windows", "Windows privilege escalation"),
        ("UAC bypass", "UAC bypass techniques"),
        ("scheduled task persistence", "Scheduled task persistence"),
        ("registry persistence", "Registry-based persistence"),
        ("credential dumping LSASS", "LSASS credential dumping"),
        ("password spraying", "Password spraying attacks"),
        ("defense evasion DLL hijacking", "DLL hijacking evasion"),
        ("process hollowing", "Process hollowing technique"),
        ("lateral movement SMB", "SMB lateral movement"),
        ("remote desktop protocol", "RDP-based movement"),
    )

    print("\n[INFO] Running Cyber Security Test Suite:")
    print("=" * 50)

    for number, (query_text, summary) in enumerate(sample_queries, 1):
        print(f"\n#{number} {summary}")
        print(f"Query: '{query_text}'")
        print("-" * 30)

        try:
            matches = kb.search(query_text, top_k=3)
            display_detailed_results(matches)
        except Exception as exc:
            print(f"[ERROR] Error: {exc}")
|
|
|
|
|
def main():
    """Parse the command line and dispatch to the chosen sub-command.

    ``ingest`` builds the knowledge base and exits with status 0 on success
    or 1 on failure. ``test`` runs retrieval tests against a stored
    knowledge base. With no sub-command the help text is printed and the
    process exits with status 1.
    """
    usage_examples = """
Examples:
python build_cyber_database.py ingest --techniques-json ./mitre_data/techniques.json
python build_cyber_database.py test --query "process injection"
python build_cyber_database.py test --interactive
python build_cyber_database.py test --query "privilege escalation" --filter-tactics "privilege-escalation"
"""
    default_store = "./cyber_knowledge_base"
    default_model = "google/embeddinggemma-300m"

    parser = argparse.ArgumentParser(
        description="MITRE ATT&CK Cyber Knowledge Base Management",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # --- "ingest" sub-command -------------------------------------------
    ingest_parser = subparsers.add_parser(
        "ingest", help="Ingest MITRE ATT&CK techniques and build knowledge base"
    )
    ingest_parser.add_argument(
        "--techniques-json",
        default="./mitre_data/techniques.json",
        help="Path to techniques.json file",
    )
    ingest_parser.add_argument(
        "--persist-dir",
        default=default_store,
        help="Directory to store the knowledge base",
    )
    ingest_parser.add_argument(
        "--embedding-model",
        default=default_model,
        help="Embedding model name",
    )
    # --reset is already on by default; --no-reset writes False to the same dest.
    ingest_parser.add_argument(
        "--reset",
        action="store_true",
        default=True,
        help="Reset knowledge base before ingestion (default: True)",
    )
    ingest_parser.add_argument(
        "--no-reset",
        dest="reset",
        action="store_false",
        help="Do not reset existing knowledge base",
    )

    # --- "test" sub-command ---------------------------------------------
    test_parser = subparsers.add_parser(
        "test", help="Test retrieval on existing knowledge base"
    )
    test_parser.add_argument("--query", help="Single query to test")
    test_parser.add_argument(
        "--filter-tactics",
        nargs="+",
        help="Filter by tactics (e.g., --filter-tactics defense-evasion privilege-escalation)",
    )
    test_parser.add_argument(
        "--filter-platforms",
        nargs="+",
        help="Filter by platforms (e.g., --filter-platforms Windows Linux)",
    )
    test_parser.add_argument(
        "--interactive", action="store_true", help="Interactive testing mode"
    )
    test_parser.add_argument(
        "--persist-dir",
        default=default_store,
        help="Directory where knowledge base is stored",
    )
    test_parser.add_argument(
        "--embedding-model",
        default=default_model,
        help="Embedding model name",
    )

    args = parser.parse_args()

    if args.command == "ingest":
        ingested = ingest_techniques(args)
        sys.exit(0 if ingested else 1)

    if args.command == "test":
        test_retrieval(args)
        return

    # No sub-command given: show usage and signal failure.
    parser.print_help()
    sys.exit(1)
|
|
|
|
|
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|