# Source: slm-code-engine/backend/app/rag/embedder.py
# (uploaded via huggingface_hub, revision f9adcbf)
"""
Code embedder using sentence-transformers
Converts code snippets into vector embeddings for similarity search
"""
import logging
from typing import List, Optional
import numpy as np
logger = logging.getLogger(__name__)
class CodeEmbedder:
    """Generates embeddings for code using sentence-transformers.

    The underlying model is loaded lazily on first use, so importing this
    module stays cheap and environments without sentence-transformers
    installed can still construct the class.
    """

    # CodeBERT-style models accept at most 512 tokens; ~2000 characters is a
    # conservative pre-tokenization cutoff that avoids wasted tokenizer work.
    MAX_CODE_CHARS = 2000

    def __init__(self, model_name: str = "microsoft/codebert-base"):
        """
        Initialize the code embedder.

        Args:
            model_name: HuggingFace model for code embeddings.
                Default: microsoft/codebert-base (125M params, fast)
        """
        self.model_name = model_name
        # Loaded on demand by initialize(); None until then.
        self.model: Optional[object] = None

    def initialize(self):
        """Load the embedding model (lazy loading; idempotent).

        Raises:
            Exception: Propagated if the model cannot be imported or loaded.
        """
        if self.model is not None:
            return
        try:
            # Imported here (not at module top) so the heavy dependency is
            # only required when embeddings are actually requested.
            from sentence_transformers import SentenceTransformer

            logger.info("Loading embedding model: %s", self.model_name)
            self.model = SentenceTransformer(self.model_name)
            logger.info("Embedding model loaded successfully")
        except Exception:
            # logger.exception records the full traceback for diagnosis.
            logger.exception("Failed to load embedding model")
            raise

    def embed(self, code: str) -> np.ndarray:
        """
        Generate embedding for a single code snippet.

        Args:
            code: Source code string.

        Returns:
            Embedding vector as numpy array.

        Raises:
            Exception: Propagated from model loading or encoding.
        """
        if self.model is None:
            self.initialize()
        try:
            # Slicing is a no-op for short strings, so no length check needed.
            code = code[: self.MAX_CODE_CHARS]
            return self.model.encode(code, convert_to_numpy=True)
        except Exception:
            logger.exception("Failed to generate embedding")
            raise

    def embed_batch(self, codes: List[str]) -> np.ndarray:
        """
        Generate embeddings for multiple code snippets.

        Args:
            codes: List of source code strings (may be empty).

        Returns:
            Matrix of embeddings (n_samples x embedding_dim).

        Raises:
            Exception: Propagated from model loading or encoding.
        """
        if self.model is None:
            self.initialize()
        try:
            truncated = [c[: self.MAX_CODE_CHARS] for c in codes]
            return self.model.encode(
                truncated,
                convert_to_numpy=True,
                show_progress_bar=True,
            )
        except Exception:
            logger.exception("Failed to generate batch embeddings")
            raise