slm-code-engine / backend /app /rag /retriever.py
vienoux's picture
Upload folder using huggingface_hub
f9adcbf verified
"""
Code retriever - High-level interface for RAG
Combines embedding and vector search to retrieve similar code examples
"""
import logging
from typing import List, Dict, Any, Optional
from pathlib import Path
from .embedder import CodeEmbedder
from .vector_store import VectorStore
from app.models.schemas import Language, TaskType
logger = logging.getLogger(__name__)
class CodeRetriever:
    """High-level interface for code example retrieval.

    Combines a ``CodeEmbedder`` (turns code into vectors) with a
    ``VectorStore`` (stores vectors alongside per-example metadata and
    answers nearest-neighbour searches). Both components are initialized
    lazily on the first call that needs them.
    """

    # When filtering by language/task, request this many times ``k``
    # candidates from the vector store so that enough survive the filter.
    _FILTER_OVERSAMPLE = 3

    def __init__(
        self,
        embedder: Optional[CodeEmbedder] = None,
        vector_store: Optional[VectorStore] = None,
        index_path: Optional[str] = None,
        embedding_dim: int = 768,
    ):
        """
        Initialize code retriever

        Args:
            embedder: CodeEmbedder instance (creates default if None)
            vector_store: VectorStore instance (creates default if None)
            index_path: Path to FAISS index file; only used when a default
                VectorStore is created
            embedding_dim: Vector dimension of the default VectorStore
                (768 = CodeBERT). Must match the embedder's output size.
                Ignored when ``vector_store`` is given.
        """
        # Compare against None explicitly: a caller-supplied component that
        # happens to be falsy (e.g. defines __len__ and is empty) must not be
        # silently replaced by a fresh default instance.
        self.embedder = embedder if embedder is not None else CodeEmbedder()
        self.vector_store = (
            vector_store
            if vector_store is not None
            else VectorStore(embedding_dim=embedding_dim, index_path=index_path)
        )
        self.initialized = False

    @staticmethod
    def _enum_value(value: Any) -> Any:
        """Return ``value.value`` for enum-like objects, else ``str(value)``.

        Normalizes languages/tasks to the plain values stored in, and
        compared against, the index metadata.
        """
        return value.value if hasattr(value, "value") else str(value)

    def initialize(self):
        """Initialize embedder and vector store; repeated calls are no-ops.

        ``initialized`` is only set once both components initialized without
        raising, so a failed attempt is tried again on the next call.

        Raises:
            Exception: Whatever the embedder or vector store raises; it is
                logged with its traceback and re-raised.
        """
        if self.initialized:
            return
        try:
            logger.info("Initializing CodeRetriever...")
            self.embedder.initialize()
            self.vector_store.initialize()
        except Exception as e:
            logger.exception("Failed to initialize CodeRetriever: %s", e)
            raise
        self.initialized = True
        logger.info("CodeRetriever initialized successfully")

    def add_examples(
        self,
        codes: List[str],
        languages: List[Language],
        tasks: List[TaskType],
        descriptions: Optional[List[str]] = None,
    ):
        """
        Add code examples to the index

        Args:
            codes: List of code snippets
            languages: Programming language of each snippet (same length
                as ``codes``)
            tasks: Task type of each snippet (same length as ``codes``)
            descriptions: Optional descriptions; snippets beyond the end of
                this list get ``None`` as their description

        Raises:
            ValueError: If ``languages`` or ``tasks`` does not have the same
                length as ``codes`` (previously ``zip`` silently truncated,
                leaving embeddings and metadata misaligned).
            Exception: Embedder / vector store errors are logged and re-raised.
        """
        if len(languages) != len(codes) or len(tasks) != len(codes):
            raise ValueError(
                "codes, languages and tasks must have the same length "
                f"(got {len(codes)}, {len(languages)}, {len(tasks)})"
            )
        if not self.initialized:
            self.initialize()
        if not codes:
            # Nothing to embed; avoid calling embed_batch / add with empty input.
            return

        try:
            logger.info("Adding %d code examples to index", len(codes))
            embeddings = self.embedder.embed_batch(codes)

            descs = descriptions or []
            metadata = [
                {
                    "code": code,
                    "language": self._enum_value(lang),
                    "task": self._enum_value(task),
                    "description": descs[i] if i < len(descs) else None,
                }
                for i, (code, lang, task) in enumerate(zip(codes, languages, tasks))
            ]

            self.vector_store.add(embeddings, metadata)
            logger.info("Successfully added %d examples", len(codes))
        except Exception as e:
            logger.exception("Failed to add examples: %s", e)
            raise

    def retrieve(
        self,
        query_code: str,
        language: Optional[Language] = None,
        task: Optional[TaskType] = None,
        k: int = 3,
    ) -> List[Dict[str, Any]]:
        """
        Retrieve similar code examples

        Args:
            query_code: Code snippet to find similar examples for
            language: Filter by programming language (optional)
            task: Filter by task type (optional)
            k: Maximum number of examples to return; ``k <= 0`` returns []

        Returns:
            Up to ``k`` dicts with keys ``code``, ``language``, ``task``,
            ``description`` and ``similarity_score``, in the order the vector
            store returned them. Best-effort: any error during embedding or
            search is logged and an empty list is returned.
        """
        if not self.initialized:
            self.initialize()
        if k <= 0:
            return []

        try:
            logger.debug("Retrieving %d similar examples for query", k)
            lang_filter = None if language is None else self._enum_value(language)
            task_filter = None if task is None else self._enum_value(task)

            query_embedding = self.embedder.embed(query_code)

            # Over-fetch when filtering, since some hits will be discarded.
            filtering = lang_filter is not None or task_filter is not None
            search_k = k * self._FILTER_OVERSAMPLE if filtering else k
            results = self.vector_store.search(query_embedding, k=search_k)

            filtered_results = []
            for distance, metadata in results:
                if lang_filter is not None and metadata.get("language") != lang_filter:
                    continue
                if task_filter is not None and metadata.get("task") != task_filter:
                    continue
                filtered_results.append({
                    "code": metadata.get("code"),
                    "language": metadata.get("language"),
                    "task": metadata.get("task"),
                    "description": metadata.get("description"),
                    # Maps distance 0 -> 1.0 and decreases as distance grows.
                    # NOTE(review): assumes non-negative distances (e.g. L2);
                    # confirm against the VectorStore's metric.
                    "similarity_score": 1.0 / (1.0 + distance),
                })
                if len(filtered_results) >= k:
                    break

            logger.info("Retrieved %d similar examples", len(filtered_results))
            return filtered_results
        except Exception as e:
            # Deliberately best-effort: retrieval failure degrades to "no context".
            logger.exception("Failed to retrieve examples: %s", e)
            return []

    def save(self):
        """Save the vector store index; no-op if not yet initialized."""
        if self.initialized:
            self.vector_store.save()

    def clear(self):
        """Clear all indexed examples; no-op if not yet initialized."""
        if self.initialized:
            self.vector_store.clear()

    def build_context(
        self,
        query_code: str,
        language: Optional[Language] = None,
        task: Optional[TaskType] = None,
        k: int = 3,
    ) -> str:
        """
        Build context string from retrieved examples

        Args:
            query_code: Code snippet to find similar examples for
            language: Filter by programming language
            task: Filter by task type
            k: Number of examples to include

        Returns:
            Formatted context string for LLM prompts, or "" when nothing
            was retrieved
        """
        examples = self.retrieve(query_code, language, task, k)
        if not examples:
            return ""

        parts = ["Here are similar code examples:\n"]
        for i, example in enumerate(examples, 1):
            parts.append(f"\nExample {i}:")
            if example.get("description"):
                parts.append(f"Description: {example['description']}")
            # retrieve() always sets these keys but they may hold None; fall
            # back so the fence never reads "```None" and join() never sees None.
            parts.append(f"```{example.get('language') or 'python'}")
            parts.append(example.get("code") or "")
            parts.append("```")
        return "\n".join(parts)