Spaces:
Sleeping
Sleeping
| """ | |
| Code retriever - High-level interface for RAG | |
| Combines embedding and vector search to retrieve similar code examples | |
| """ | |
| import logging | |
| from typing import List, Dict, Any, Optional | |
| from pathlib import Path | |
| from .embedder import CodeEmbedder | |
| from .vector_store import VectorStore | |
| from app.models.schemas import Language, TaskType | |
# Module-level logger shared by CodeRetriever; named after this module.
logger = logging.getLogger(name=__name__)
class CodeRetriever:
    """High-level interface for code example retrieval.

    Wraps a ``CodeEmbedder`` (used via ``embed`` / ``embed_batch``) and a
    ``VectorStore`` (used via ``add`` / ``search`` / ``save`` / ``clear``) so
    callers can index code snippets and later fetch the most similar ones,
    optionally filtered by language and task.
    """

    # When filtering by language/task, over-fetch this many candidates per
    # requested result so enough hits survive the post-search filter.
    _FILTER_OVERSAMPLE = 3

    def __init__(
        self,
        embedder: Optional[CodeEmbedder] = None,
        vector_store: Optional[VectorStore] = None,
        index_path: Optional[str] = None,
        embedding_dim: int = 768,
    ):
        """
        Initialize code retriever

        Args:
            embedder: CodeEmbedder instance (creates default if None)
            vector_store: VectorStore instance (creates default if None)
            index_path: Path to FAISS index file; only used when a default
                VectorStore is created
            embedding_dim: Vector dimension for a default VectorStore
                (768 = CodeBERT); should match the size of the vectors the
                embedder produces. Ignored when ``vector_store`` is given.
        """
        # Compare against None explicitly: a caller-supplied object that
        # happens to be falsy (e.g. a store defining __len__ that is currently
        # empty) must not be silently replaced by a fresh default instance.
        self.embedder = embedder if embedder is not None else CodeEmbedder()
        self.vector_store = (
            vector_store
            if vector_store is not None
            else VectorStore(embedding_dim=embedding_dim, index_path=index_path)
        )
        self.initialized = False

    @staticmethod
    def _as_str(value: Any) -> Any:
        """Normalize an enum member (its ``.value``) or any other value (``str``)."""
        return value.value if hasattr(value, "value") else str(value)

    def initialize(self):
        """Initialize embedder and vector store; no-op if already initialized.

        Raises:
            Exception: Whatever the embedder or vector store raises. It is
                logged with its traceback and re-raised; ``initialized`` stays
                False so a later call will retry.
        """
        if self.initialized:
            return
        try:
            logger.info("Initializing CodeRetriever...")
            self.embedder.initialize()
            self.vector_store.initialize()
        except Exception as e:
            logger.exception("Failed to initialize CodeRetriever: %s", e)
            raise
        self.initialized = True
        logger.info("CodeRetriever initialized successfully")

    def add_examples(
        self,
        codes: List[str],
        languages: List[Language],
        tasks: List[TaskType],
        descriptions: Optional[List[str]] = None
    ):
        """
        Add code examples to the index

        Args:
            codes: List of code snippets
            languages: Programming language of each snippet (one per code)
            tasks: Task type of each snippet (one per code)
            descriptions: Optional list of descriptions; may be shorter than
                ``codes``, missing entries are stored as None

        Raises:
            ValueError: If ``languages`` or ``tasks`` does not have exactly one
                entry per code snippet.
            Exception: Any error from embedding or storing is logged and
                re-raised.
        """
        # One embedding is computed per code, so metadata must line up
        # one-to-one with it; zip() alone would silently truncate and leave
        # vectors and metadata misaligned.
        if len(languages) != len(codes) or len(tasks) != len(codes):
            raise ValueError(
                "codes, languages and tasks must have the same length "
                f"(got {len(codes)}, {len(languages)}, {len(tasks)})"
            )
        if not self.initialized:
            self.initialize()
        if not codes:
            # Nothing to index; skip calling embed_batch with an empty batch.
            return
        try:
            logger.info("Adding %d code examples to index", len(codes))
            embeddings = self.embedder.embed_batch(codes)
            descs = descriptions or []
            metadata = [
                {
                    "code": code,
                    "language": self._as_str(lang),
                    "task": self._as_str(task),
                    "description": descs[i] if i < len(descs) else None,
                }
                for i, (code, lang, task) in enumerate(zip(codes, languages, tasks))
            ]
            self.vector_store.add(embeddings, metadata)
            logger.info("Successfully added %d examples", len(codes))
        except Exception as e:
            logger.exception("Failed to add examples: %s", e)
            raise

    def retrieve(
        self,
        query_code: str,
        language: Optional[Language] = None,
        task: Optional[TaskType] = None,
        k: int = 3
    ) -> List[Dict[str, Any]]:
        """
        Retrieve similar code examples

        Args:
            query_code: Code snippet to find similar examples for
            language: Filter by programming language (optional)
            task: Filter by task type (optional)
            k: Maximum number of examples to retrieve

        Returns:
            Up to ``k`` dicts with keys "code", "language", "task",
            "description" and "similarity_score" (1 / (1 + distance), so higher
            means closer). Returns an empty list when ``k <= 0`` or when
            embedding/search fails (the failure is logged, not raised).
        """
        if k <= 0:
            return []
        if not self.initialized:
            self.initialize()
        try:
            logger.debug("Retrieving %d similar examples for query", k)
            # Normalize the filters once, the same way add_examples stores them.
            language_filter = self._as_str(language) if language is not None else None
            task_filter = self._as_str(task) if task is not None else None
            filtering = language_filter is not None or task_filter is not None

            query_embedding = self.embedder.embed(query_code)
            # Over-fetch when filtering, since some hits will be discarded.
            search_k = k * self._FILTER_OVERSAMPLE if filtering else k
            results = self.vector_store.search(query_embedding, k=search_k)

            filtered_results: List[Dict[str, Any]] = []
            for distance, metadata in results:
                if language_filter is not None and metadata.get("language") != language_filter:
                    continue
                if task_filter is not None and metadata.get("task") != task_filter:
                    continue
                filtered_results.append({
                    "code": metadata.get("code"),
                    "language": metadata.get("language"),
                    "task": metadata.get("task"),
                    "description": metadata.get("description"),
                    # Map a distance (lower = closer) into (0, 1], higher = closer.
                    "similarity_score": 1.0 / (1.0 + distance),
                })
                if len(filtered_results) >= k:
                    break
            logger.info("Retrieved %d similar examples", len(filtered_results))
            return filtered_results
        except Exception as e:
            # Deliberate best-effort: a retrieval failure degrades to "no
            # examples" for the caller, but keep the traceback in the log.
            logger.exception("Failed to retrieve examples: %s", e)
            return []

    def save(self):
        """Save the vector store index (no-op if never initialized)."""
        if self.initialized:
            self.vector_store.save()

    def clear(self):
        """Clear all indexed examples (no-op if never initialized)."""
        if self.initialized:
            self.vector_store.clear()

    def build_context(
        self,
        query_code: str,
        language: Optional[Language] = None,
        task: Optional[TaskType] = None,
        k: int = 3
    ) -> str:
        """
        Build context string from retrieved examples

        Args:
            query_code: Code snippet to find similar examples for
            language: Filter by programming language
            task: Filter by task type
            k: Number of examples to include

        Returns:
            Formatted context string for LLM prompts, or "" when nothing was
            retrieved. Each example is rendered as a fenced code block tagged
            with its language ("python" when the language is missing).
        """
        examples = self.retrieve(query_code, language, task, k)
        if not examples:
            return ""
        context_parts = ["Here are similar code examples:\n"]
        for i, example in enumerate(examples, 1):
            context_parts.append(f"\nExample {i}:")
            description = example.get("description")
            if description:
                context_parts.append(f"Description: {description}")
            # retrieve() always sets these keys, but their values can be None
            # for sparse metadata, so a .get() default would not apply; a None
            # code would also make str.join raise TypeError.
            # NOTE(review): code that itself contains ``` will break the fence.
            context_parts.append(f"```{example.get('language') or 'python'}")
            context_parts.append(example.get("code") or "")
            context_parts.append("```")
        return "\n".join(context_parts)