| import networkx as nx |
| import json |
| import os |
| import asyncio |
| import nest_asyncio |
| import tqdm |
| |
| import os.path |
| import tempfile |
| import subprocess |
| from typing import List, Optional, Dict |
| import logging |
| import urllib.parse |
|
|
| from .ModelService import create_model_service |
| from .Node import Node, DirectoryNode, FileNode, ChunkNode, EntityNode |
| from .CodeParser import CodeParser |
| from .EntityExtractor import HybridEntityExtractor |
| from .CodeIndex import CodeIndex |
| from .utils.logger_utils import setup_logger |
| from .utils.parsing_utils import read_directory_files_recursively, get_language_from_filename |
| from .utils.path_utils import prepare_input_path, build_entity_alias_map, resolve_entity_call |
| from .EntityChunkMapper import EntityChunkMapper |
|
|
| LOGGER_NAME = 'REPO_KNOWLEDGE_GRAPH_LOGGER' |
|
|
| MODEL_SERVICE_TYPES = ['openai', 'sentence-transformers'] |
|
|
|
|
| |
| class RepoKnowledgeGraph: |
| """ |
| RepoKnowledgeGraph builds a knowledge graph of a code repository. |
| It parses source files, extracts code entities and relationships, and organizes them |
| into a directed acyclic graph (DAG) with additional semantic edges. |
| |
| Use `from_path()` or `load_graph_from_file()` to create instances. |
| """ |
|
|
| def __init__(self): |
| """ |
| Private constructor. Use from_path() or load_graph_from_file() instead. |
| """ |
| raise RuntimeError( |
| "Cannot instantiate RepoKnowledgeGraph directly. " |
| "Use RepoKnowledgeGraph.from_path() or RepoKnowledgeGraph.load_graph_from_file() instead." |
| ) |
|
|
| def _initialize(self, model_service_kwargs: dict, code_index_kwargs: Optional[dict] = None): |
| """Internal initialization method.""" |
| setup_logger(LOGGER_NAME) |
| self.logger = logging.getLogger(LOGGER_NAME) |
| self.logger.info('Initializing RepoKnowledgeGraph instance.') |
| self.code_parser = CodeParser() |
| |
| |
| index_type = (code_index_kwargs or {}).get('index_type', 'hybrid') |
| skip_embedder = index_type == 'keyword-only' |
| if skip_embedder: |
| self.logger.info('Using keyword-only index, skipping embedder initialization') |
| |
| self.model_service = create_model_service(skip_embedder=skip_embedder, **model_service_kwargs) |
| self.entities = {} |
| self.graph = nx.DiGraph() |
| self.knowledge_graph = nx.DiGraph() |
| self.code_index = None |
| self.entity_extractor = HybridEntityExtractor() |
|
|
| def __iter__(self): |
| |
| return (node_data['data'] for _, node_data in self.graph.nodes(data=True)) |
|
|
| def __getitem__(self, node_id): |
| return self.graph.nodes[node_id]['data'] |
| |
|
|
| @classmethod |
| def from_path(cls, path: str, skip_dirs: Optional[list] = None, index_nodes: bool = True, describe_nodes=False, |
| extract_entities: bool = False, model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None): |
| if skip_dirs is None: |
| skip_dirs = [] |
| if model_service_kwargs is None: |
| model_service_kwargs = {} |
| """ |
| Alternative constructor to build a RepoKnowledgeGraph from a path, with options to skip directories |
| and control entity extraction and node description. |
| |
| Args: |
| path (str): Path to the root of the code repository. |
| skip_dirs (list): List of directory names to skip. |
| index_nodes (bool): Whether to build a code index. |
| describe_nodes (bool): Whether to generate descriptions for code chunks. |
| extract_entities (bool): Whether to extract entities from code. |
| |
| Returns: |
| RepoKnowledgeGraph: The constructed knowledge graph. |
| """ |
| instance = cls.__new__(cls) |
| instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs) |
|
|
| instance.logger.info(f"Preparing to build knowledge graph from path: {path}") |
|
|
| prepared_path = prepare_input_path(path) |
| instance.logger.debug(f"Prepared input path: {prepared_path}") |
|
|
| |
| try: |
| loop = asyncio.get_running_loop() |
| except RuntimeError: |
| loop = None |
|
|
| if loop and loop.is_running(): |
| instance.logger.debug("Detected running event loop, applying nest_asyncio.") |
| nest_asyncio.apply() |
| task = instance._initial_parse_path_async(prepared_path, skip_dirs=skip_dirs, index_nodes=index_nodes, |
| describe_nodes=describe_nodes, extract_entities=extract_entities) |
| loop.run_until_complete(task) |
| else: |
| instance.logger.debug("No running event loop, using asyncio.run.") |
| asyncio.run(instance._initial_parse_path_async(prepared_path, skip_dirs=skip_dirs, index_nodes=index_nodes, |
| describe_nodes=describe_nodes, |
| extract_entities=extract_entities)) |
|
|
| instance.logger.info("Parsing files and building initial nodes...") |
| instance.logger.info("Initial parse and node creation complete. Building relationships between nodes...") |
| instance._build_relationships() |
|
|
| if index_nodes: |
| instance.logger.info("Building code index for all nodes in the graph...") |
| instance.code_index = CodeIndex(list(instance), model_service=instance.model_service, **(code_index_kwargs or {})) |
|
|
| instance.logger.info("Knowledge graph construction from path completed successfully.") |
| return instance |
|
|
| @classmethod |
| def from_repo( |
| cls, |
| repo_url: str, |
| skip_dirs: Optional[list] = None, |
| index_nodes: bool = True, |
| describe_nodes: bool = False, |
| extract_entities: bool = False, |
| model_service_kwargs: Optional[dict] = None, |
| code_index_kwargs: Optional[dict]=None, |
| github_token: Optional[str] = None, |
| allow_unauthenticated_clone: bool = True, |
| ): |
| """ |
| Alternative constructor to build a RepoKnowledgeGraph from a remote git repository URL. |
| |
| Args: |
| repo_url (str): Git repository URL (SSH or HTTPS). |
| skip_dirs (list): List of directory names to skip. |
| index_nodes (bool): Whether to build a code index. |
| describe_nodes (bool): Whether to generate descriptions for code chunks. |
| extract_entities (bool): Whether to extract entities from code. |
| github_token (str, optional): Personal access token to access private GitHub repos. |
| If not provided, the method will look for the `GITHUB_OAUTH_TOKEN` environment variable. |
| allow_unauthenticated_clone (bool): If True, attempt to clone without a token when none is provided. |
| If False, raise an error when no token is available. |
| |
| Returns: |
| RepoKnowledgeGraph: The constructed knowledge graph. |
| """ |
| if skip_dirs is None: |
| skip_dirs = [] |
| if model_service_kwargs is None: |
| model_service_kwargs = {} |
|
|
| instance = cls.__new__(cls) |
| instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs) |
|
|
| instance.logger.info(f"Starting knowledge graph build from remote repository: {repo_url}") |
|
|
| |
| token = github_token or os.environ.get('GITHUB_OAUTH_TOKEN') |
|
|
| with tempfile.TemporaryDirectory() as tmpdirname: |
| clone_url = repo_url |
| try: |
| if repo_url.startswith('git@'): |
| |
| clone_url = repo_url.replace(':', '/').split('git@')[-1] |
| clone_url = f'https://{clone_url}' |
|
|
| if token and clone_url.startswith('https://'): |
| encoded_token = urllib.parse.quote(token, safe='') |
| clone_url = clone_url.replace('https://', f'https://{encoded_token}@') |
| elif not token and not allow_unauthenticated_clone: |
| raise ValueError( |
| "GitHub token not provided and unauthenticated clone is disabled. " |
| "Set allow_unauthenticated_clone=True or provide a token." |
| ) |
|
|
| instance.logger.debug(f"Running git clone: {clone_url} -> {tmpdirname}") |
| subprocess.run(['git', 'clone', clone_url, tmpdirname], check=True) |
|
|
| except Exception as e: |
| instance.logger.error(f"Failed to clone repository {repo_url} using URL {clone_url}: {e}") |
| raise |
|
|
| instance.logger.info(f"Repository successfully cloned to: {tmpdirname}") |
|
|
| return cls.from_path( |
| tmpdirname, |
| skip_dirs=skip_dirs, |
| index_nodes=index_nodes, |
| describe_nodes=describe_nodes, |
| extract_entities=extract_entities, |
| model_service_kwargs=model_service_kwargs, |
| code_index_kwargs=code_index_kwargs |
| ) |
|
|
| async def _initial_parse_path_async(self, path: str, skip_dirs: list, index_nodes=True, describe_nodes=True, |
| extract_entities: bool = True): |
| self.logger.info(f"Beginning async parsing of repository at path: {path}") |
| """ |
| Orchestrates the parsing and graph construction process: |
| 1. Reads files and splits into chunks. |
| 2. Extracts entities and relationships. |
| 3. Builds chunk, file, directory, and root nodes. |
| 4. Aggregates entity information. |
| |
| Args: |
| path (str): Root path to parse. |
| skip_dirs (list): Directories to skip. |
| index_nodes (bool): Whether to build code index. |
| describe_nodes (bool): Whether to generate descriptions. |
| extract_entities (bool): Whether to extract entities. |
| """ |
|
|
| |
| level1_node_contents = read_directory_files_recursively( |
| path, skip_dirs=skip_dirs, |
| skip_pattern=r"(?:\.log$|\.json$|(?:^|/)(?:\.git|\.idea|__pycache__|\.cache)(?:/|$)|(?:^|/)(?:changelog|ChangeLog)(?:\.[a-z0-9]+)?$|\.cache$)" |
| ) |
| self.logger.debug(f"Found {len(level1_node_contents)} files to process.") |
| self.logger.info("Chunk nodes creation step started.") |
| chunk_info = await self._create_chunk_nodes( |
| level1_node_contents, extract_entities, describe_nodes, index_nodes, root_path=path |
| ) |
| self.logger.info("Chunk nodes creation step finished.") |
| self.logger.info("File nodes creation step started.") |
| file_info = self._create_file_nodes( |
| chunk_info, level1_node_contents |
| ) |
| self.logger.info("File nodes creation step finished.") |
| self.logger.info("Directory nodes creation step started.") |
| dir_agg = self._create_directory_nodes( |
| file_info |
| ) |
| self.logger.info("Directory nodes creation step finished.") |
| self.logger.info("Aggregating all nodes to root node.") |
| self._aggregate_to_root(dir_agg) |
| self.logger.info("Async parse and node aggregation fully complete.") |
|
|
    async def _create_chunk_nodes(self, level1_node_contents, extract_entities, describe_nodes, index_nodes, root_path=None):
        """
        Split every file into chunks, create a ChunkNode per chunk, and register
        declared/called entities into ``self.entities``.

        Pass 1 builds chunk nodes and registers declared entities (merging by
        name/alias); pass 2 resolves called-entity names against the alias map
        built from all declarations.

        Args:
            level1_node_contents (dict): Mapping file_path -> file content string.
            extract_entities (bool): Whether to run entity extraction on code files.
            describe_nodes (bool): Whether to ask the model service for a per-chunk summary.
            index_nodes (bool): Accepted for interface symmetry; not used directly here.
            root_path (str, optional): Repo root, joined with relative paths for extraction.

        Returns:
            dict: file_path -> {'chunk_results', 'file_declared_entities', 'file_called_entities'}.
        """
        self.logger.info(f"Starting chunk node creation for {len(level1_node_contents)} files.")
        # Only files with these extensions are treated as code for entity extraction.
        accepted_extensions = {'.py', '.c', '.cpp', '.h', '.hpp', '.java', '.js', '.ts', '.jsx', '.tsx', '.rs', '.html'}
        chunk_info = {}
        entity_mapper = EntityChunkMapper()
        total_chunks = 0

        for file_path in tqdm.tqdm(level1_node_contents, desc="Processing files for chunk nodes"):
            self.logger.debug(f"Processing file for chunk nodes: {file_path}")
            full_path = os.path.normpath(file_path)
            parts = full_path.split(os.sep)  # NOTE(review): appears unused in this method
            _, ext = os.path.splitext(file_path)
            is_code_file = ext.lower() in accepted_extensions

            self.logger.debug(f"Parsing file: {file_path}")

            parsed_content = self.code_parser.parse(file_name=file_path, file_content=level1_node_contents[file_path])
            self.logger.debug(f"Parsed {len(parsed_content)} chunks from file: {file_path}")
            total_chunks += len(parsed_content)

            if extract_entities and is_code_file:
                self.logger.debug(f"Extracting entities from code file: {file_path}")
                try:
                    # Give the extractor an absolute-ish path when a root is known.
                    extraction_file_path = os.path.join(root_path, file_path) if root_path else file_path

                    file_declared_entities, file_called_entities = self.entity_extractor.extract_entities(
                        code=level1_node_contents[file_path], file_name=extraction_file_path)
                    self.logger.debug(f"Extracted {len(file_declared_entities)} declared and {len(file_called_entities)} called entities from file: {file_path}")

                    chunk_declared_map, chunk_called_map = entity_mapper.map_entities_to_chunks(
                        file_declared_entities, file_called_entities, parsed_content, file_name=file_path)
                    self.logger.debug(f"Mapped entities to {len(parsed_content)} chunks for file: {file_path}")
                except Exception as e:
                    # Best-effort: extraction failure degrades to "no entities" for this file.
                    self.logger.error(f"Error extracting entities from {file_path}: {e}")
                    file_declared_entities, file_called_entities = [], []
                    chunk_declared_map = {i: [] for i in range(len(parsed_content))}
                    chunk_called_map = {i: [] for i in range(len(parsed_content))}
            else:
                self.logger.debug(f"Skipping entity extraction for non-code file: {file_path}")
                file_declared_entities, file_called_entities = [], []
                chunk_declared_map = {i: [] for i in range(len(parsed_content))}
                chunk_called_map = {i: [] for i in range(len(parsed_content))}

            chunk_tasks = []
            for i, chunk in enumerate(parsed_content):
                # Chunk ids are "<file_path>_<index>"; other passes rely on this format.
                chunk_id = f'{file_path}_{i}'
                self.logger.debug(f"Scheduling processing for chunk {chunk_id} of file {file_path}")

                # Default arguments bind the loop variables at definition time,
                # avoiding the classic late-binding closure pitfall.
                async def process_chunk(i=i, chunk=chunk, chunk_id=chunk_id):
                    self.logger.debug(f"Creating chunk node: {chunk_id}")
                    declared_entities = chunk_declared_map.get(i, [])
                    called_entities = chunk_called_map.get(i, [])

                    # Rebuild the alias map so this chunk sees entities
                    # registered by previously processed chunks/files.
                    temp_alias_map = build_entity_alias_map(self.entities)

                    for entity in declared_entities:
                        name = entity.get("name")
                        if not name:
                            continue

                        entity_aliases = entity.get("aliases", [])
                        canonical_name = None

                        # Prefer a direct name hit, then fall back to alias matches.
                        if name in temp_alias_map:
                            canonical_name = temp_alias_map[name]
                            self.logger.debug(f"Entity '{name}' already exists as '{canonical_name}'")
                        else:
                            for alias in entity_aliases:
                                if alias in temp_alias_map:
                                    canonical_name = temp_alias_map[alias]
                                    self.logger.debug(f"Entity '{name}' matches existing entity '{canonical_name}' via alias '{alias}'")
                                    break

                        # Merge into the existing entity, or register a new one.
                        if canonical_name:
                            entity_key = canonical_name
                        else:
                            entity_key = name
                            self.logger.debug(f"Registering new declared entity '{name}' in chunk {chunk_id}")
                            self.entities[entity_key] = {
                                "declaring_chunk_ids": [],
                                "calling_chunk_ids": [],
                                "type": [],
                                "dtype": None,
                                "aliases": []
                            }
                            temp_alias_map[entity_key] = entity_key

                        if chunk_id not in self.entities[entity_key]["declaring_chunk_ids"]:
                            self.entities[entity_key]["declaring_chunk_ids"].append(chunk_id)
                        entity_type = entity.get("type")
                        if entity_type and entity_type not in self.entities[entity_key]["type"]:
                            self.entities[entity_key]["type"].append(entity_type)
                        dtype = entity.get("dtype")
                        if dtype:
                            self.entities[entity_key]["dtype"] = dtype
                        # Record every alias (including the raw name) exactly once.
                        for alias in [name] + entity_aliases:
                            if alias and alias not in self.entities[entity_key]["aliases"]:
                                self.entities[entity_key]["aliases"].append(alias)
                                temp_alias_map[alias] = entity_key
                        self.logger.debug(f"Declared entity '{name}' registered as '{entity_key}' in chunk {chunk_id} with aliases: {self.entities[entity_key]['aliases']}")

                    # Optional LLM-generated summary; failures degrade to ''.
                    if describe_nodes:
                        self.logger.info(f"Generating description for chunk {chunk_id}")
                        try:
                            description = await self.model_service.query_async(
                                f'Summarize this {get_language_from_filename(file_path)} code chunk in a few sentences: {chunk}')
                        except Exception as e:
                            self.logger.error(f"Error generating description for chunk {chunk_id}: {e}")
                            description = ''
                    else:
                        self.logger.debug(f"No description requested for chunk {chunk_id}")
                        description = ''

                    chunk_node = ChunkNode(
                        id=chunk_id,
                        name=chunk_id,
                        path=file_path,
                        content=chunk,
                        order_in_file=i,
                        called_entities=called_entities,
                        declared_entities=declared_entities,
                        language=get_language_from_filename(file_path),
                        description=description,
                    )
                    self.logger.debug(f"Chunk node created: {chunk_id}")

                    # Embeddings are populated later (by the CodeIndex), if at all.
                    chunk_node.embedding = None
                    return (chunk_id, chunk_node, declared_entities, called_entities)

                chunk_tasks.append(process_chunk())

            chunk_results = await asyncio.gather(*chunk_tasks)
            self.logger.debug(f"Finished processing {len(chunk_results)} chunks for file {file_path}.")
            chunk_info[file_path] = {
                'chunk_results': chunk_results,
                'file_declared_entities': file_declared_entities,
                'file_called_entities': file_called_entities
            }

        self.logger.info(f"Created {total_chunks} chunk nodes from {len(level1_node_contents)} files")

        # ---- Pass 2: resolve called entities against the full alias map. ----
        self.logger.info("Starting second pass: resolving called entities using alias map...")
        alias_map = build_entity_alias_map(self.entities)
        self.logger.info(f"Built alias map with {len(alias_map)} entries for resolution")

        resolved_count = 0
        for file_path, file_data in tqdm.tqdm(chunk_info.items(), desc="Resolving called entities"):
            chunk_results = file_data['chunk_results']
            for chunk_id, chunk_node, declared_entities, called_entities in chunk_results:
                for called_name in called_entities:
                    # Skip empty/whitespace-only call names.
                    if not called_name or not called_name.strip():
                        continue

                    resolved_name = resolve_entity_call(called_name, alias_map)

                    # Resolution priority: resolver hit > direct alias hit > raw name.
                    if resolved_name:
                        entity_key = resolved_name
                    elif called_name in alias_map:
                        entity_key = alias_map[called_name]
                    else:
                        entity_key = called_name

                    if entity_key not in self.entities:
                        self.logger.debug(f"Registering new called entity '{entity_key}' (called as '{called_name}') in chunk {chunk_id}")
                        self.entities[entity_key] = {
                            "declaring_chunk_ids": [],
                            "calling_chunk_ids": [],
                            "type": [],
                            "dtype": None,
                            "aliases": []
                        }
                        # Remember the as-called spelling as an alias of the new entity.
                        if called_name != entity_key:
                            self.entities[entity_key]["aliases"].append(called_name)
                            alias_map[called_name] = entity_key

                    if chunk_id not in self.entities[entity_key]["calling_chunk_ids"]:
                        self.entities[entity_key]["calling_chunk_ids"].append(chunk_id)

                    if resolved_name and resolved_name != called_name:
                        resolved_count += 1
                        self.logger.debug(f"Called entity '{called_name}' resolved to '{entity_key}' in chunk {chunk_id}")

        self.logger.info(f"Resolved {resolved_count} entity calls to existing declarations via aliases")
        self.logger.info("All chunk nodes have been created for all files.")
        return chunk_info
|
|
| def _create_file_nodes(self, chunk_info, level1_node_contents): |
| self.logger.info("Starting file node creation.") |
| """ |
| For each file, aggregate chunk information and create FileNode objects. |
| This method remains mostly the same. |
| """ |
|
|
| def merge_entities(target, source): |
| |
| existing = set((e.get('name'), e.get('type')) for e in target) |
| for e in source: |
| k = (e.get('name'), e.get('type')) |
| if k not in existing: |
| target.append(e) |
| existing.add(k) |
|
|
| def merge_called_entities(target, source): |
| |
| existing = set(target) |
| for e in source: |
| if e not in existing: |
| target.append(e) |
| existing.add(e) |
|
|
| file_info = {} |
| for file_path, file_data in tqdm.tqdm(chunk_info.items(), desc="Creating file nodes"): |
| self.logger.info(f"Creating file node for: {file_path}") |
| parts = os.path.normpath(file_path).split(os.sep) |
|
|
| |
| chunk_results = file_data['chunk_results'] |
| file_declared_entities = list(file_data['file_declared_entities']) |
| file_called_entities = list(file_data['file_called_entities']) |
| chunk_ids = [] |
|
|
| for chunk_id, chunk_node, declared_entities, called_entities in chunk_results: |
| self.logger.info(f"Adding chunk node {chunk_id} to graph for file {file_path}") |
| self.graph.add_node(chunk_id, data=chunk_node, level=2) |
| chunk_ids.append(chunk_id) |
| |
| |
|
|
| file_node = FileNode( |
| id=file_path, |
| name=parts[-1], |
| path=file_path, |
| node_type='file', |
| content=level1_node_contents[file_path], |
| declared_entities=file_declared_entities, |
| called_entities=file_called_entities, |
| language=get_language_from_filename(file_path), |
| ) |
|
|
| self.logger.debug(f"Adding file node {file_path} to graph.") |
| self.graph.add_node(file_path, data=file_node, level=1) |
| for chunk_id in chunk_ids: |
| self.graph.add_edge(file_path, chunk_id, relation='contains') |
|
|
| file_info[file_path] = { |
| 'declared_entities': file_declared_entities, |
| 'called_entities': file_called_entities, |
| 'chunk_ids': chunk_ids, |
| 'parts': parts, |
| } |
| self.logger.info(f"File node {file_path} added to graph with {len(chunk_ids)} chunks.") |
|
|
| self.logger.info("All file nodes have been created.") |
| return file_info |
|
|
| def _create_directory_nodes(self, file_info): |
| self.logger.info("Starting directory node creation.") |
| """ |
| For each directory, aggregate file information and create DirectoryNode objects. |
| |
| Args: |
| file_info (dict): Mapping file_path -> file info dict. |
| |
| Returns: |
| dict: Mapping dir_path -> aggregated entity info. |
| """ |
|
|
| def merge_entities(target, source): |
| |
| existing = set((e.get('name'), e.get('type')) for e in target) |
| for e in source: |
| k = (e.get('name'), e.get('type')) |
| if k not in existing: |
| target.append(e) |
| existing.add(k) |
|
|
| def merge_called_entities(target, source): |
| |
| existing = set(target) |
| for e in source: |
| if e not in existing: |
| target.append(e) |
| existing.add(e) |
|
|
| dir_agg = {} |
| for file_path, info in tqdm.tqdm(file_info.items(), desc="Creating directory nodes"): |
| self.logger.info(f"Processing directory nodes for file: {file_path}") |
| parts = os.path.normpath(file_path).split(os.sep) |
| file_declared_entities = info['declared_entities'] |
| file_called_entities = info['called_entities'] |
| current_parent = 'root' |
| path_accum = '' |
| for part in parts[:-1]: |
| path_accum = os.path.join(path_accum, part) if path_accum else part |
| if path_accum not in self.graph: |
| self.logger.info(f"Adding new directory node: {path_accum}") |
| dir_node = DirectoryNode(id=path_accum, name=part, path=path_accum) |
| self.graph.add_node(path_accum, data=dir_node, level=1) |
| self.graph.add_edge(current_parent, path_accum, relation='contains') |
| if path_accum not in dir_agg: |
| dir_agg[path_accum] = {'declared_entities': [], 'called_entities': []} |
| merge_entities(dir_agg[path_accum]['declared_entities'], file_declared_entities) |
| merge_called_entities(dir_agg[path_accum]['called_entities'], file_called_entities) |
| current_parent = path_accum |
| |
| self.graph.add_edge(current_parent, file_path, relation='contains') |
| self.logger.info("All directory nodes created.") |
| return dir_agg |
|
|
| def _aggregate_to_root(self, dir_agg): |
| self.logger.info("Aggregating directory information to root node.") |
| """ |
| Aggregate all directory entity information to the root node. |
| |
| Args: |
| dir_agg (dict): Mapping dir_path -> aggregated entity info. |
| """ |
|
|
| def merge_entities(target, source): |
| |
| existing = set((e.get('name'), e.get('type')) for e in target) |
| for e in source: |
| k = (e.get('name'), e.get('type')) |
| if k not in existing: |
| target.append(e) |
| existing.add(k) |
|
|
| def merge_called_entities(target, source): |
| |
| existing = set(target) |
| for e in source: |
| if e not in existing: |
| target.append(e) |
| existing.add(e) |
|
|
| root_node = Node(id='root', name='root', node_type='root') |
| self.graph.add_node('root', data=root_node, level=0) |
| root_declared_entities = [] |
| root_called_entities = [] |
| for dir_path, agg in tqdm.tqdm(dir_agg.items(), desc="Aggregating to root"): |
| node = self.graph.nodes[dir_path]['data'] |
| if not hasattr(node, 'declared_entities'): |
| node.declared_entities = [] |
| if not hasattr(node, 'called_entities'): |
| node.called_entities = [] |
| merge_entities(node.declared_entities, agg['declared_entities']) |
| merge_called_entities(node.called_entities, agg['called_entities']) |
| merge_entities(root_declared_entities, agg['declared_entities']) |
| merge_called_entities(root_called_entities, agg['called_entities']) |
| if not hasattr(root_node, 'declared_entities'): |
| root_node.declared_entities = [] |
| if not hasattr(root_node, 'called_entities'): |
| root_node.called_entities = [] |
| merge_entities(root_node.declared_entities, root_declared_entities) |
| merge_called_entities(root_node.called_entities, root_called_entities) |
| self.logger.info("Aggregation to root node complete.") |
|
|
| def _build_relationships(self): |
| self.logger.info("Building relationships between chunk nodes based on entities.") |
| """ |
| Build relationships between chunk nodes and entity nodes based on self.entities. |
| For each entity in self.entities: |
| 1. Create an EntityNode with entity_name as the id |
| 2. Create edges from declaring chunks to entity node (declares relationship) |
| 3. Create edges from entity node to calling chunks (called_by relationship) |
| 4. Resolve called entity names using aliases for better matching |
| """ |
| from .Node import EntityNode |
| edges_created = 0 |
| entity_nodes_created = 0 |
| |
| |
| self.logger.info("Building entity alias map for call resolution...") |
| alias_map = build_entity_alias_map(self.entities) |
| self.logger.info(f"Built alias map with {len(alias_map)} entries") |
|
|
| |
| for entity_name, info in tqdm.tqdm(self.entities.items(), desc="Creating entity nodes"): |
| |
| entity_types = info.get('type', []) |
| entity_type = entity_types[0] if entity_types else '' |
| declaring_chunks = info.get('declaring_chunk_ids', []) |
| calling_chunks = info.get('calling_chunk_ids', []) |
| aliases = info.get('aliases', []) |
|
|
| |
| entity_node = EntityNode( |
| id=entity_name, |
| name=entity_name, |
| entity_type=entity_type, |
| declaring_chunk_ids=declaring_chunks, |
| calling_chunk_ids=calling_chunks, |
| aliases=aliases |
| ) |
| |
| |
| self.graph.add_node(entity_name, data=entity_node, level=3) |
| entity_nodes_created += 1 |
| |
| |
| if aliases: |
| self.logger.debug(f"Created EntityNode '{entity_name}' with aliases: {aliases}") |
|
|
| |
| for declarer_id in declaring_chunks: |
| if declarer_id in self.graph: |
| self.graph.add_edge(declarer_id, entity_name, relation='declares') |
| edges_created += 1 |
| |
| |
| for caller_id in calling_chunks: |
| if caller_id in self.graph and caller_id not in declaring_chunks: |
| self.graph.add_edge(entity_name, caller_id, relation='called_by') |
| edges_created += 1 |
|
|
| |
| self.logger.info("Resolving entity calls using alias matching...") |
| resolved_calls = 0 |
|
|
| for entity_name, info in tqdm.tqdm(self.entities.items(), desc="Resolving entity calls"): |
| |
| if info.get('declaring_chunk_ids'): |
| continue |
|
|
| |
| resolved_name = resolve_entity_call(entity_name, alias_map) |
|
|
| if resolved_name and resolved_name != entity_name: |
| |
| calling_chunks = info.get('calling_chunk_ids', []) |
|
|
| if resolved_name in self.entities: |
| for caller_id in calling_chunks: |
| if caller_id in self.graph: |
| |
| if not self.graph.has_edge(resolved_name, caller_id): |
| self.graph.add_edge(resolved_name, caller_id, relation='called_by') |
| edges_created += 1 |
| resolved_calls += 1 |
| self.logger.debug(f"Resolved call: '{entity_name}' -> '{resolved_name}' in chunk {caller_id}") |
|
|
| self.logger.info(f"_build_relationships: Created {entity_nodes_created} entity nodes, " |
| f"{edges_created} edges, and resolved {resolved_calls} entity calls using aliases.") |
|
|
| def get_entity_by_alias(self, alias: str) -> Optional[str]: |
| """ |
| Get the canonical entity name for a given alias. |
| |
| Args: |
| alias: An alias of an entity (e.g., 'MyClass' or 'module.MyClass') |
| |
| Returns: |
| Canonical entity name if found, None otherwise |
| """ |
| alias_map = build_entity_alias_map(self.entities) |
| return alias_map.get(alias) |
|
|
| def resolve_entity_references(self) -> Dict[str, List[str]]: |
| """ |
| Resolve all entity references in the knowledge graph using aliases. |
| Returns a mapping of unresolved entity calls to their potential matches. |
| |
| Returns: |
| Dictionary mapping called entity names to list of potential canonical matches |
| """ |
| alias_map = build_entity_alias_map(self.entities) |
| resolutions = {} |
|
|
| for entity_name, info in self.entities.items(): |
| |
| if not info.get('declaring_chunk_ids') and info.get('calling_chunk_ids'): |
| resolved = resolve_entity_call(entity_name, alias_map) |
| if resolved: |
| resolutions[entity_name] = resolved |
|
|
| return resolutions |
|
|
    def print_tree(self, max_depth=None, start_node_id='root', level=0, prefix=""):
        """
        Print the repository tree structure using the graph with 'contains' edges.

        Args:
            max_depth (int, optional): Maximum depth to print. None = unlimited.
            start_node_id (str): ID of the node to start from. Default is 'root'.
            level (int): Internal use only (used for recursion).
            prefix (str): Internal use only (used for formatting output).
        """
        if max_depth is not None and level > max_depth:
            self.logger.debug(f"Max depth {max_depth} reached at node {start_node_id}.")
            return

        if start_node_id not in self.graph:
            self.logger.warning(f"Start node '{start_node_id}' not found in graph.")
            return

        try:
            node_data = self[start_node_id]
        except KeyError as e:
            self.logger.error(f"KeyError when accessing node {start_node_id}: {e}")
            self.logger.error(f"Available node attributes: {list(self.graph.nodes[start_node_id].keys())}")
            # Self-heal: if the node is missing its 'data' payload, synthesize a
            # plausible node object, store it back on the graph, and stop
            # recursing down this branch.
            if 'data' not in self.graph.nodes[start_node_id]:
                self.logger.warning(f"Node {start_node_id} has no 'data' attribute, using node itself")
                if start_node_id == 'root':
                    node_data = Node(id='root', name='root', node_type='root')
                    self.graph.nodes[start_node_id]['data'] = node_data
                else:
                    # Heuristic reconstruction: trailing "_<int>" suffix -> chunk,
                    # dotted basename -> file, otherwise a directory.
                    name = start_node_id.split('/')[-1] if '/' in start_node_id else start_node_id
                    if '_' in start_node_id and start_node_id.split('_')[-1].isdigit():
                        node_data = ChunkNode(id=start_node_id, name=name, node_type='chunk')
                    elif '.' in name:
                        node_data = FileNode(id=start_node_id, name=name, node_type='file', path=start_node_id)
                    else:
                        node_data = DirectoryNode(id=start_node_id, name=name, node_type='directory',
                                                  path=start_node_id)
                    self.graph.nodes[start_node_id]['data'] = node_data
            return

        # Pick an emoji per node type for readable output.
        if node_data.node_type == 'file':
            node_symbol = "📄"
        elif node_data.node_type == 'chunk':
            node_symbol = "📝"
        elif node_data.node_type == 'root':
            node_symbol = "📁"
        elif node_data.node_type == 'directory':
            node_symbol = "📂"
        else:
            node_symbol = "📦"

        if level == 0:
            print(f"{node_symbol} {node_data.name} ({node_data.node_type})")
        else:
            print(f"{prefix}└── {node_symbol} {node_data.name} ({node_data.node_type})")

        # Recurse only along structural 'contains' edges (not semantic links).
        children = [
            child for child in self.graph.successors(start_node_id)
            if self.graph.edges[start_node_id, child].get('relation') == 'contains'
        ]

        child_count = len(children)
        for i, child_id in enumerate(children):
            is_last = i == child_count - 1
            new_prefix = prefix + (" " if is_last else "│ ")
            self.print_tree(max_depth, start_node_id=child_id, level=level + 1, prefix=new_prefix)
|
|
| def to_dict(self): |
| self.logger.info("Serializing graph to dictionary.") |
| from .Node import EntityNode |
| graph_data = { |
| 'nodes': [], |
| 'edges': [] |
| } |
|
|
| for node_id, node_attrs in tqdm.tqdm(self.graph.nodes(data=True), desc="Serializing nodes"): |
| if 'data' not in node_attrs: |
| self.logger.warning(f"Node {node_id} has no 'data' attribute, skipping in serialization") |
| continue |
|
|
| node = node_attrs['data'] |
| node_dict = { |
| 'id': node.id or node_id, |
| 'class': node.__class__.__name__, |
| 'data': { |
| 'id': node.id or node_id, |
| 'name': node.name, |
| 'node_type': node.node_type, |
| 'description': getattr(node, 'description', ''), |
| 'declared_entities': list(getattr(node, 'declared_entities', [])), |
| 'called_entities': list(getattr(node, 'called_entities', [])), |
| } |
| } |
|
|
| |
| if isinstance(node, FileNode): |
| node_dict['data']['path'] = node.path |
| node_dict['data']['content'] = node.content |
| node_dict['data']['language'] = getattr(node, 'language', '') |
|
|
| |
| if isinstance(node, ChunkNode): |
| node_dict['data']['order_in_file'] = getattr(node, 'order_in_file', 0) |
| node_dict['data']['embedding'] = getattr(node, 'embedding', None) |
| |
| |
| if isinstance(node, EntityNode): |
| node_dict['data']['entity_type'] = getattr(node, 'entity_type', '') |
| node_dict['data']['declaring_chunk_ids'] = list(getattr(node, 'declaring_chunk_ids', [])) |
| node_dict['data']['calling_chunk_ids'] = list(getattr(node, 'calling_chunk_ids', [])) |
| node_dict['data']['aliases'] = list(getattr(node, 'aliases', [])) |
|
|
| graph_data['nodes'].append(node_dict) |
|
|
| for u, v, attrs in tqdm.tqdm(self.graph.edges(data=True), desc="Serializing edges"): |
| edge_data = { |
| 'source': u, |
| 'target': v, |
| 'relation': attrs.get('relation', '') |
| } |
| if 'entities' in attrs: |
| edge_data['entities'] = list(attrs['entities']) |
| graph_data['edges'].append(edge_data) |
|
|
| self.logger.info("Serialization complete.") |
| return graph_data |
|
|
    @classmethod
    def from_dict(cls, data_dict, index_nodes: bool = True, use_embed: bool = True,
                  model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None):
        """
        Rebuild a RepoKnowledgeGraph from a dictionary produced by `to_dict()`.

        Args:
            data_dict: Dictionary with 'nodes' and 'edges' lists (see `to_dict`).
            index_nodes (bool): Build a CodeIndex over the nodes after loading.
            use_embed (bool): Forwarded to the CodeIndex as 'use_embed' unless
                code_index_kwargs already contains that key.
            model_service_kwargs (dict, optional): Passed to `_initialize`.
            code_index_kwargs (dict, optional): Passed to `_initialize` and CodeIndex.

        Returns:
            RepoKnowledgeGraph: The reconstructed instance.
        """
        # Bypass __init__ (which raises) and run the private initializer,
        # mirroring the construction pattern of the other factory methods.
        instance = cls.__new__(cls)
        instance._initialize(model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs)

        instance.logger.info("Deserializing graph from dictionary.")

        # Map serialized class names back to node classes; unknown names
        # fall back to the plain Node class below.
        node_classes = {
            'Node': Node,
            'FileNode': FileNode,
            'ChunkNode': ChunkNode,
            'DirectoryNode': DirectoryNode,
            'EntityNode': EntityNode,
        }

        # Ensure a root node exists even if the serialized data lacks one.
        # (Assumes _initialize created self.graph — TODO confirm against head.)
        root_found = any(node_data['id'] == 'root' for node_data in data_dict['nodes'])
        if not root_found:
            instance.logger.warning("Root node not found in the data, creating one")
            root_node = Node(id='root', name='root', node_type='root')
            instance.graph.add_node('root', data=root_node, level=0)

        # Rebuild node objects and insert them into the graph.
        for node_data in tqdm.tqdm(data_dict['nodes'], desc="Rebuilding nodes"):
            cls_name = node_data['class']
            node_cls = node_classes.get(cls_name, Node)
            kwargs = node_data['data']

            # Some payloads carry the id only at the top level of the entry.
            if not kwargs.get('id'):
                kwargs['id'] = node_data['id']

            # Normalize entity lists so constructors always get fresh lists.
            kwargs['declared_entities'] = list(kwargs.get('declared_entities', []))
            kwargs['called_entities'] = list(kwargs.get('called_entities', []))

            # Fill in class-specific defaults missing from older payloads.
            if node_cls in (FileNode, ChunkNode):
                kwargs.setdefault('path', '')
                kwargs.setdefault('content', '')
                kwargs.setdefault('language', '')
            if node_cls == ChunkNode:
                kwargs.setdefault('order_in_file', 0)
                kwargs.setdefault('embedding', [])

            if node_cls == EntityNode:
                kwargs.setdefault('entity_type', '')
                kwargs.setdefault('declaring_chunk_ids', [])
                kwargs.setdefault('calling_chunk_ids', [])
                kwargs.setdefault('aliases', [])

            node_instance = node_cls(**kwargs)
            instance.graph.add_node(node_data['id'], data=node_instance, level=instance._infer_level(node_instance))

        # Re-add edges; edges referencing unknown nodes are skipped with a warning.
        for edge in tqdm.tqdm(data_dict['edges'], desc="Rebuilding edges"):
            source = edge['source']
            target = edge['target']
            if source in instance.graph and target in instance.graph:
                edge_kwargs = {'relation': edge.get('relation', '')}
                if 'entities' in edge:
                    edge_kwargs['entities'] = list(edge['entities'])
                instance.graph.add_edge(source, target, **edge_kwargs)
            else:
                instance.logger.warning(f"Cannot add edge {source} -> {target}, nodes don't exist")

        # Rebuild the entity lookup table from the per-node entity lists.
        # Only ids of ChunkNode-typed graph nodes are recorded as
        # declaring/calling chunk ids.
        instance.entities = {}
        for node_id, node_attrs in tqdm.tqdm(instance.graph.nodes(data=True), desc="Rebuilding entities"):
            node = node_attrs['data']
            declared_entities = getattr(node, 'declared_entities', [])
            called_entities = getattr(node, 'called_entities', [])
            for entity in declared_entities:
                # Declared entities may be plain name strings or dicts
                # carrying extra metadata ('name', 'type', 'dtype').
                if isinstance(entity, dict):
                    name = entity.get('name')
                else:
                    name = entity
                if not name:
                    continue
                if name not in instance.entities:
                    instance.entities[name] = {
                        "declaring_chunk_ids": [],
                        "calling_chunk_ids": [],
                        "type": [],
                        "dtype": None
                    }

                if node_id not in instance.entities[name]["declaring_chunk_ids"]:
                    if node_id in instance.graph and isinstance(instance.graph.nodes[node_id]["data"], ChunkNode):
                        instance.entities[name]["declaring_chunk_ids"].append(node_id)
                if isinstance(entity, dict):
                    entity_type = entity.get("type")
                    if entity_type and entity_type not in instance.entities[name]["type"]:
                        instance.entities[name]["type"].append(entity_type)
                    dtype = entity.get("dtype")
                    if dtype:
                        # NOTE(review): last writer wins when several chunks
                        # declare the same entity with different dtypes.
                        instance.entities[name]["dtype"] = dtype
            for called_name in called_entities:
                if not called_name:
                    continue
                if called_name not in instance.entities:
                    instance.entities[called_name] = {
                        "declaring_chunk_ids": [],
                        "calling_chunk_ids": [],
                        "type": [],
                        "dtype": None
                    }
                if node_id not in instance.entities[called_name]["calling_chunk_ids"]:
                    if node_id in instance.graph and isinstance(instance.graph.nodes[node_id]["data"], ChunkNode):
                        instance.entities[called_name]["calling_chunk_ids"].append(node_id)

        if index_nodes:
            instance.logger.info("Building code index after deserialization.")
            # Honor an explicit 'use_embed' in code_index_kwargs; otherwise
            # fall back to the use_embed argument.
            code_idx_kwargs = code_index_kwargs or {}
            if 'use_embed' not in code_idx_kwargs:
                code_idx_kwargs['use_embed'] = use_embed
            instance.code_index = CodeIndex(list(instance), model_service=instance.model_service, **code_idx_kwargs)

        instance.logger.info("Deserialization complete.")
        return instance
|
|
| def _infer_level(self, node): |
| """Infer the level of a node based on its type""" |
| if node.node_type == 'root': |
| return 0 |
| elif node.node_type in ('file', 'directory'): |
| return 1 |
| elif node.node_type == 'chunk': |
| return 2 |
| return 1 |
|
|
| def save_graph_to_file(self, filepath: str): |
| self.logger.info(f"Saving graph to file: {filepath}") |
| with open(filepath, 'w') as f: |
| json.dump(self.to_dict(), f, indent=2) |
| self.logger.info("Graph saved successfully.") |
|
|
| @classmethod |
| def load_graph_from_file(cls, filepath: str, index_nodes=True, use_embed: bool = True, |
| model_service_kwargs: Optional[dict] = None, code_index_kwargs: Optional[dict] = None): |
| if model_service_kwargs is None: |
| model_service_kwargs = {} |
| with open(filepath, 'r') as f: |
| data = json.load(f) |
| logging.getLogger(LOGGER_NAME).info(f"Loaded graph data from file: {filepath}") |
| return cls.from_dict(data, use_embed=use_embed, index_nodes=index_nodes, |
| model_service_kwargs=model_service_kwargs, code_index_kwargs=code_index_kwargs) |
|
|
    def to_hf_dataset(
        self,
        repo_id: str,
        save_embeddings: bool = True,
        private: bool = False,
        token: Optional[str] = None,
        commit_message: Optional[str] = None,
    ):
        """
        Save the knowledge graph to a HuggingFace dataset on the Hub.

        The graph is serialized into two splits:
        - 'nodes': Contains all node data
        - 'edges': Contains all edge relationships

        Args:
            repo_id (str): The HuggingFace dataset repository ID (e.g., 'username/dataset-name')
            save_embeddings (bool): If True, saves embedding vectors for chunk nodes.
                If False, embeddings are excluded to reduce dataset size.
            private (bool): Whether the dataset should be private. Defaults to False.
            token (str, optional): HuggingFace API token. If not provided, uses the token
                from huggingface_hub login or HF_TOKEN environment variable.
            commit_message (str, optional): Custom commit message for the upload.

        Returns:
            str: URL of the uploaded dataset
        """
        # Imported lazily so the HuggingFace stack stays an optional
        # dependency. (HfApi/DatasetDict are imported only so a missing
        # install fails fast here; they are not used directly below.)
        try:
            from datasets import Dataset, DatasetDict
            from huggingface_hub import HfApi
        except ImportError:
            raise ImportError(
                "huggingface_hub and datasets are required for HuggingFace integration. "
                "Install them with: pip install huggingface_hub datasets"
            )

        self.logger.info(f"Preparing to save knowledge graph to HuggingFace dataset: {repo_id}")
        self.logger.info(f"save_embeddings={save_embeddings}")

        # Flatten every node into one uniform record schema: HF datasets need
        # the same columns on every row, so class-specific fields are filled
        # with empty/sentinel values for other classes, and list-valued
        # fields are stored as JSON strings.
        nodes_data = []
        for node_id, node_attrs in tqdm.tqdm(self.graph.nodes(data=True), desc="Serializing nodes for HF dataset"):
            # Nodes without a 'data' payload cannot be reconstructed; skip.
            if 'data' not in node_attrs:
                self.logger.warning(f"Node {node_id} has no 'data' attribute, skipping")
                continue

            node = node_attrs['data']
            node_record = {
                'node_id': node.id or node_id,
                'node_class': node.__class__.__name__,
                'name': node.name,
                'node_type': node.node_type,
                'description': getattr(node, 'description', '') or '',
                'declared_entities': json.dumps(list(getattr(node, 'declared_entities', []))),
                'called_entities': json.dumps(list(getattr(node, 'called_entities', []))),
            }

            # File-specific columns (empty strings for non-file nodes).
            if isinstance(node, FileNode):
                node_record['path'] = node.path
                node_record['content'] = node.content
                node_record['language'] = getattr(node, 'language', '')
            else:
                node_record['path'] = ''
                node_record['content'] = ''
                node_record['language'] = ''

            # Chunk-specific columns; order_in_file == -1 marks "not a chunk".
            if isinstance(node, ChunkNode):
                node_record['order_in_file'] = getattr(node, 'order_in_file', 0)
                if save_embeddings:
                    embedding = getattr(node, 'embedding', None)
                    node_record['embedding'] = json.dumps(embedding if embedding is not None else [])
                else:
                    node_record['embedding'] = json.dumps([])
            else:
                node_record['order_in_file'] = -1
                node_record['embedding'] = json.dumps([])

            # Entity-specific columns (empty for non-entity nodes).
            if isinstance(node, EntityNode):
                node_record['entity_type'] = getattr(node, 'entity_type', '')
                node_record['declaring_chunk_ids'] = json.dumps(list(getattr(node, 'declaring_chunk_ids', [])))
                node_record['calling_chunk_ids'] = json.dumps(list(getattr(node, 'calling_chunk_ids', [])))
                node_record['aliases'] = json.dumps(list(getattr(node, 'aliases', [])))
            else:
                node_record['entity_type'] = ''
                node_record['declaring_chunk_ids'] = json.dumps([])
                node_record['calling_chunk_ids'] = json.dumps([])
                node_record['aliases'] = json.dumps([])

            nodes_data.append(node_record)

        # Edges get their own split; optional 'entities' payloads are
        # JSON-encoded for schema uniformity.
        edges_data = []
        for source, target, attrs in tqdm.tqdm(self.graph.edges(data=True), desc="Serializing edges for HF dataset"):
            edge_record = {
                'source': source,
                'target': target,
                'relation': attrs.get('relation', ''),
                'entities': json.dumps(list(attrs.get('entities', []))) if 'entities' in attrs else json.dumps([])
            }
            edges_data.append(edge_record)

        nodes_dataset = Dataset.from_list(nodes_data)
        edges_dataset = Dataset.from_list(edges_data)

        self.logger.info(f"Created dataset with {len(nodes_data)} nodes and {len(edges_data)} edges")

        # Build a default commit message when the caller didn't supply one.
        if commit_message is None:
            base_commit_message = f"Upload knowledge graph ({len(nodes_data)} nodes, {len(edges_data)} edges)"
            if not save_embeddings:
                base_commit_message += " [embeddings excluded]"
        else:
            base_commit_message = commit_message

        # Nodes and edges are pushed as two configs of the same dataset repo.
        self.logger.info(f"Pushing nodes dataset to HuggingFace Hub: {repo_id}")
        nodes_dataset.push_to_hub(
            repo_id=repo_id,
            config_name="nodes",
            private=private,
            token=token,
            commit_message=f"{base_commit_message} - nodes"
        )

        self.logger.info(f"Pushing edges dataset to HuggingFace Hub: {repo_id}")
        edges_dataset.push_to_hub(
            repo_id=repo_id,
            config_name="edges",
            private=private,
            token=token,
            commit_message=f"{base_commit_message} - edges"
        )

        url = f"https://huggingface.co/datasets/{repo_id}"
        self.logger.info(f"Dataset successfully uploaded to: {url}")
        return url
|
|
    @classmethod
    def from_hf_dataset(
        cls,
        repo_id: str,
        index_nodes: bool = True,
        use_embed: bool = True,
        model_service_kwargs: Optional[dict] = None,
        code_index_kwargs: Optional[dict] = None,
        token: Optional[str] = None,
        revision: Optional[str] = None,
    ):
        """
        Load a knowledge graph from a HuggingFace dataset on the Hub.

        Args:
            repo_id (str): The HuggingFace dataset repository ID (e.g., 'username/dataset-name')
            index_nodes (bool): Whether to build a code index after loading. Defaults to True.
            use_embed (bool): Whether to use existing embeddings from the dataset. Defaults to True.
            model_service_kwargs (dict, optional): Arguments for the model service.
            code_index_kwargs (dict, optional): Arguments for the code index.
            token (str, optional): HuggingFace API token for private datasets.
            revision (str, optional): Git revision (branch, tag, or commit) to load from.

        Returns:
            RepoKnowledgeGraph: The loaded knowledge graph instance.
        """
        # Imported lazily so 'datasets' stays an optional dependency.
        try:
            from datasets import load_dataset
        except ImportError:
            raise ImportError(
                "datasets library is required for HuggingFace integration. "
                "Install it with: pip install datasets"
            )

        if model_service_kwargs is None:
            model_service_kwargs = {}

        logger = logging.getLogger(LOGGER_NAME)
        logger.info(f"Loading knowledge graph from HuggingFace dataset: {repo_id}")

        # Nodes and edges were uploaded as two configs (see to_hf_dataset).
        logger.info("Loading nodes config...")
        nodes_dataset = load_dataset(repo_id, name="nodes", token=token, revision=revision)
        logger.info("Loading edges config...")
        edges_dataset = load_dataset(repo_id, name="edges", token=token, revision=revision)

        # load_dataset returns a DatasetDict; the records live in 'train'.
        nodes_data = nodes_dataset['train']
        edges_data = edges_dataset['train']

        logger.info(f"Loaded {len(nodes_data)} nodes and {len(edges_data)} edges from dataset")

        # Convert the flat HF records back into the nested layout expected by
        # from_dict(), decoding the JSON-encoded list columns.
        graph_data = {
            'nodes': [],
            'edges': []
        }

        for record in tqdm.tqdm(nodes_data, desc="Reconstructing nodes from HF dataset"):
            node_dict = {
                'id': record['node_id'],
                'class': record['node_class'],
                'data': {
                    'id': record['node_id'],
                    'name': record['name'],
                    'node_type': record['node_type'],
                    'description': record['description'],
                    'declared_entities': json.loads(record['declared_entities']),
                    'called_entities': json.loads(record['called_entities']),
                }
            }

            # Restore class-specific fields only for the matching classes.
            if record['node_class'] in ('FileNode', 'ChunkNode'):
                node_dict['data']['path'] = record['path']
                node_dict['data']['content'] = record['content']
                node_dict['data']['language'] = record['language']

            if record['node_class'] == 'ChunkNode':
                node_dict['data']['order_in_file'] = record['order_in_file']
                embedding = json.loads(record['embedding'])
                # Drop stored embeddings when the caller disabled use_embed.
                if use_embed and embedding:
                    node_dict['data']['embedding'] = embedding
                else:
                    node_dict['data']['embedding'] = []

            if record['node_class'] == 'EntityNode':
                node_dict['data']['entity_type'] = record['entity_type']
                node_dict['data']['declaring_chunk_ids'] = json.loads(record['declaring_chunk_ids'])
                node_dict['data']['calling_chunk_ids'] = json.loads(record['calling_chunk_ids'])
                node_dict['data']['aliases'] = json.loads(record['aliases'])

            graph_data['nodes'].append(node_dict)

        for record in tqdm.tqdm(edges_data, desc="Reconstructing edges from HF dataset"):
            edge_dict = {
                'source': record['source'],
                'target': record['target'],
                'relation': record['relation'],
            }
            # 'entities' is attached only when non-empty, mirroring to_dict().
            entities = json.loads(record['entities'])
            if entities:
                edge_dict['entities'] = entities

            graph_data['edges'].append(edge_dict)

        logger.info("Dataset reconstruction complete, building graph...")

        # Delegate graph construction and indexing to from_dict().
        return cls.from_dict(
            graph_data,
            index_nodes=index_nodes,
            use_embed=use_embed,
            model_service_kwargs=model_service_kwargs,
            code_index_kwargs=code_index_kwargs
        )
|
|
| def get_neighbors(self, node_id): |
| self.logger.debug(f"Getting neighbors for node: {node_id}") |
| |
| neighbors = set() |
| for n in self.graph.successors(node_id): |
| neighbors.add(n) |
| for n in self.graph.predecessors(node_id): |
| neighbors.add(n) |
| |
| for u, v in self.graph.edges(node_id): |
| if u == node_id: |
| neighbors.add(v) |
| else: |
| neighbors.add(u) |
| for u, v in self.graph.in_edges(node_id): |
| if v == node_id: |
| neighbors.add(u) |
| else: |
| neighbors.add(v) |
| return [self.graph.nodes[n]['data'] for n in neighbors if 'data' in self.graph.nodes[n]] |
|
|
| def get_previous_chunk(self, node_id: str) -> ChunkNode: |
| self.logger.debug(f"Getting previous chunk for node: {node_id}") |
| node = self[node_id] |
| |
| if not isinstance(node, ChunkNode): |
| raise Exception(f'Cannot get previous chunk on node of type {type(node)}') |
|
|
| if node.order_in_file == 0: |
| self.logger.warning(f'Cannot get previous chunk for first node') |
| return None |
|
|
| file_path = node.path |
| previous_chunk_id = f'{file_path}_{node.order_in_file - 1}' |
|
|
| if previous_chunk_id not in self.graph: |
| raise Exception(f'Previous chunk {previous_chunk_id} not found in graph') |
|
|
| previous_chunk = self[previous_chunk_id] |
| return previous_chunk |
|
|
| def get_next_chunk(self, node_id: str) -> ChunkNode: |
| self.logger.debug(f"Getting next chunk for node: {node_id}") |
| node = self[node_id] |
| |
| if not isinstance(node, ChunkNode): |
| raise Exception(f'Cannot get previous chunk on node of type {type(node)}') |
|
|
| file_path = node.path |
| next_chunk_id = f'{file_path}_{node.order_in_file + 1}' |
|
|
| if next_chunk_id not in self.graph: |
| self.logger.warning(f'Next chunk {next_chunk_id} not found in graph, it might be the last chunk') |
| return None |
| previous_chunk = self[next_chunk_id] |
| return previous_chunk |
|
|
| def get_all_chunks(self) -> List[ChunkNode]: |
| self.logger.debug("Getting all chunk nodes.") |
| chunk_nodes = [] |
| for node in self: |
| if isinstance(node, ChunkNode): |
| chunk_nodes.append(node) |
| return chunk_nodes |
|
|
| def get_all_files(self) -> List[FileNode]: |
| self.logger.debug("Getting all file nodes.") |
| """ |
| Get all FileNodes in the knowledge graph. |
| |
| Returns: |
| List[FileNode]: A list of FileNodes in the graph. |
| """ |
| file_nodes = [] |
| for node in self.graph.nodes(data=True): |
| node_data = node[1]['data'] |
| |
| if isinstance(node_data, FileNode) and node_data.node_type == 'file': |
| file_nodes.append(node_data) |
| return file_nodes |
|
|
| def get_chunks_of_file(self, file_node_id: str) -> List[ChunkNode]: |
| self.logger.debug(f"Getting chunks for file node: {file_node_id}") |
| """ |
| Get all ChunkNodes associated with a specific FileNode. |
| |
| Args: |
| file_node (FileNode): The file node to get chunks for. |
| |
| Returns: |
| List[ChunkNode]: A list of ChunkNodes associated with the file. |
| """ |
| chunk_nodes = [] |
| for node in self.graph.neighbors(file_node_id): |
| |
| edge_data = self.graph.get_edge_data(file_node_id, node) |
| node_data = self.graph.nodes[node]['data'] |
| if ( |
| isinstance(node_data, ChunkNode) |
| and node_data.node_type == 'chunk' |
| and edge_data is not None |
| and edge_data.get('relation') == 'contains' |
| ): |
| chunk_nodes.append(node_data) |
| return chunk_nodes |
|
|
| def find_path(self, source_id: str, target_id: str, max_depth: int = 5) -> dict: |
| """ |
| Find the shortest path between two nodes in the knowledge graph. |
| |
| Args: |
| source_id (str): The ID of the source node. |
| target_id (str): The ID of the target node. |
| max_depth (int): Maximum depth to search for a path. Defaults to 5. |
| |
| Returns: |
| dict: A dictionary containing path information or error message. |
| """ |
| self.logger.debug(f"Finding path from {source_id} to {target_id} with max_depth={max_depth}") |
| g = self.graph |
|
|
| if source_id not in g: |
| return {"error": f"Source node '{source_id}' not found."} |
| if target_id not in g: |
| return {"error": f"Target node '{target_id}' not found."} |
|
|
| try: |
| path = nx.shortest_path(g, source=source_id, target=target_id) |
|
|
| if len(path) - 1 > max_depth: |
| return { |
| "source_id": source_id, |
| "target_id": target_id, |
| "path": [], |
| "length": len(path) - 1, |
| "text": f"Path exists but exceeds max_depth of {max_depth} (actual length: {len(path) - 1})" |
| } |
|
|
| |
| path_details = [] |
| for i, node_id in enumerate(path): |
| node = g.nodes[node_id]['data'] |
| node_info = { |
| "node_id": node_id, |
| "name": getattr(node, 'name', 'Unknown'), |
| "type": getattr(node, 'node_type', 'Unknown'), |
| "step": i |
| } |
|
|
| |
| if i < len(path) - 1: |
| next_node_id = path[i + 1] |
| edge_data = g.get_edge_data(node_id, next_node_id) |
| node_info["edge_to_next"] = edge_data.get('relation', 'Unknown') if edge_data else 'Unknown' |
|
|
| path_details.append(node_info) |
|
|
| |
| text = f"Path from '{source_id}' to '{target_id}' (length: {len(path) - 1}):\n\n" |
| for i, node_info in enumerate(path_details): |
| text += f"{i}. {node_info['name']} ({node_info['type']})\n" |
| text += f" Node ID: {node_info['node_id']}\n" |
| if 'edge_to_next' in node_info: |
| text += f" --[{node_info['edge_to_next']}]--> \n" |
|
|
| return { |
| "source_id": source_id, |
| "target_id": target_id, |
| "path": path_details, |
| "length": len(path) - 1, |
| "text": text |
| } |
|
|
| except nx.NetworkXNoPath: |
| return { |
| "source_id": source_id, |
| "target_id": target_id, |
| "path": [], |
| "length": -1, |
| "text": f"No path found between '{source_id}' and '{target_id}'" |
| } |
| except Exception as e: |
| self.logger.error(f"Error finding path: {str(e)}") |
| return {"error": f"Error finding path: {str(e)}"} |
|
|
    def get_subgraph(self, node_id: str, depth: int = 2, edge_types: Optional[List[str]] = None) -> dict:
        """
        Extract a subgraph around a node up to a specified depth.

        Args:
            node_id (str): The ID of the central node.
            depth (int): The depth/radius of the subgraph to extract. Defaults to 2.
            edge_types (Optional[List[str]]): Optional list of edge types to include (e.g., ['calls', 'contains']).

        Returns:
            dict: A dictionary containing subgraph information or error message.
        """
        self.logger.debug(f"Getting subgraph for node {node_id} with depth={depth}, edge_types={edge_types}")
        g = self.graph

        if node_id not in g:
            return {"error": f"Node '{node_id}' not found."}

        # Breadth-first expansion: nodes_at_depth is the current frontier,
        # all_nodes accumulates every node reached so far.
        nodes_at_depth = {node_id}
        all_nodes = {node_id}

        for d in range(depth):
            next_level = set()
            for n in nodes_at_depth:
                # Outgoing neighbors, optionally filtered by edge relation.
                for neighbor in g.successors(n):
                    if edge_types is None:
                        next_level.add(neighbor)
                    else:
                        edge_data = g.get_edge_data(n, neighbor)
                        if edge_data and edge_data.get('relation') in edge_types:
                            next_level.add(neighbor)

                # Incoming neighbors, with the same filtering.
                for neighbor in g.predecessors(n):
                    if edge_types is None:
                        next_level.add(neighbor)
                    else:
                        edge_data = g.get_edge_data(neighbor, n)
                        if edge_data and edge_data.get('relation') in edge_types:
                            next_level.add(neighbor)

            # Expand only genuinely new nodes on the next iteration.
            nodes_at_depth = next_level - all_nodes
            all_nodes.update(next_level)

        # Independent copy of the induced subgraph. NOTE(review): the induced
        # subgraph keeps ALL edges between the collected nodes, including
        # relations outside edge_types — only traversal was filtered.
        subgraph = g.subgraph(all_nodes).copy()

        # Summarize nodes...
        nodes = []
        for n in subgraph.nodes():
            node = subgraph.nodes[n]['data']
            nodes.append({
                "node_id": n,
                "name": getattr(node, 'name', 'Unknown'),
                "type": getattr(node, 'node_type', 'Unknown')
            })

        # ...and edges.
        edges = []
        for source, target, data in subgraph.edges(data=True):
            edges.append({
                "source": source,
                "target": target,
                "relation": data.get('relation', 'Unknown')
            })

        # Build the human-readable summary header.
        text = f"Subgraph around '{node_id}' (depth: {depth}):\n"
        if edge_types:
            text += f"Edge types filter: {', '.join(edge_types)}\n"
        text += f"\nNodes: {len(nodes)}\n"
        text += f"Edges: {len(edges)}\n\n"

        # Group nodes by type, listing at most five examples per type.
        nodes_by_type = {}
        for node in nodes:
            node_type = node['type']
            if node_type not in nodes_by_type:
                nodes_by_type[node_type] = []
            nodes_by_type[node_type].append(node)

        for node_type, type_nodes in nodes_by_type.items():
            text += f"{node_type} ({len(type_nodes)}):\n"
            for node in type_nodes[:5]:
                text += f" - {node['name']} ({node['node_id']})\n"
            if len(type_nodes) > 5:
                text += f" ... and {len(type_nodes) - 5} more\n"
            text += "\n"

        # Count edges per relation type for the summary.
        edge_by_relation = {}
        for edge in edges:
            relation = edge['relation']
            edge_by_relation[relation] = edge_by_relation.get(relation, 0) + 1

        if edge_by_relation:
            text += "Edge types:\n"
            for relation, count in edge_by_relation.items():
                text += f" - {relation}: {count}\n"

        return {
            "center_node_id": node_id,
            "depth": depth,
            "edge_types_filter": edge_types,
            "node_count": len(nodes),
            "edge_count": len(edges),
            "nodes": nodes,
            "edges": edges,
            "nodes_by_type": nodes_by_type,
            "edge_by_relation": edge_by_relation,
            "text": text
        }
|
|