| | import gradio as gr |
| | import torch |
| | import torch.nn.functional as F |
| | from transformers import AutoTokenizer, AutoModel |
| | import pandas as pd |
| | import sys |
| | import os |
| | import shutil |
| | from pathlib import Path |
| | import chromadb |
| | from chromadb.config import Settings |
| | import uuid |
| | import tempfile |
| |
|
| | |
| | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) |
| | from scripts.core.ingestion.ingest import GitCrawler |
| | from scripts.core.ingestion.chunk import RepoChunker |
| |
|
| | |
# --- Configuration ------------------------------------------------------------
# Two embedding models are compared side by side: stock CodeBERT and a
# fine-tuned variant pulled from the Hugging Face hub.
BASELINE_MODEL = "microsoft/codebert-base"
FINETUNED_MODEL = "shubharuidas/codebert-base-code-embed-mrl-langchain-langgraph"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Persistent Chroma store lives under the current working directory.
DB_DIR = Path(os.path.abspath("data/chroma_db_comparison"))
DB_DIR.mkdir(parents=True, exist_ok=True)

# --- Model loading (module-import side effect; downloads on first run) --------
print(f"Loading models on {DEVICE}...")
print("1. Loading baseline model...")
baseline_tokenizer = AutoTokenizer.from_pretrained(BASELINE_MODEL)
baseline_model = AutoModel.from_pretrained(BASELINE_MODEL)
baseline_model.to(DEVICE)
baseline_model.eval()  # inference mode (disables dropout, etc.)

print("2. Loading fine-tuned model...")
finetuned_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL)
finetuned_model = AutoModel.from_pretrained(FINETUNED_MODEL)
finetuned_model.to(DEVICE)
finetuned_model.eval()
print("Both models loaded!")

# --- Vector store: one collection per model, cosine distance for HNSW ---------
chroma_client = chromadb.PersistentClient(path=str(DB_DIR))
baseline_collection = chroma_client.get_or_create_collection(name="baseline_rag", metadata={"hnsw:space": "cosine"})
finetuned_collection = chroma_client.get_or_create_collection(name="finetuned_rag", metadata={"hnsw:space": "cosine"})
| |
|
| | |
def compute_baseline_embeddings(text_list):
    """Embed *text_list* with the baseline model.

    Returns an L2-normalized ``(len(text_list), hidden_size)`` tensor, or
    ``None`` when the input list is empty.
    """
    if not text_list:
        return None
    encoded = baseline_tokenizer(
        text_list, return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(DEVICE)
    with torch.no_grad():
        hidden = baseline_model(**encoded).last_hidden_state
        # Mean-pool over the token dimension, then normalize so dot products
        # equal cosine similarity downstream.
        pooled = hidden.mean(dim=1)
    return F.normalize(pooled, p=2, dim=1)
| |
|
def compute_finetuned_embeddings(text_list):
    """Embed *text_list* with the fine-tuned model.

    Returns an L2-normalized ``(len(text_list), hidden_size)`` tensor, or
    ``None`` when the input list is empty.
    """
    if not text_list:
        return None
    encoded = finetuned_tokenizer(
        text_list, return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(DEVICE)
    with torch.no_grad():
        hidden = finetuned_model(**encoded).last_hidden_state
        # Mean-pool over the token dimension, then normalize so dot products
        # equal cosine similarity downstream.
        pooled = hidden.mean(dim=1)
    return F.normalize(pooled, p=2, dim=1)
| |
|
| | |
def reset_baseline():
    """Drop and recreate the baseline collection, wiping its entire index."""
    global baseline_collection
    chroma_client.delete_collection("baseline_rag")
    baseline_collection = chroma_client.get_or_create_collection(
        name="baseline_rag", metadata={"hnsw:space": "cosine"}
    )
    return "Baseline database reset."
| |
|
def reset_finetuned():
    """Drop and recreate the fine-tuned collection, wiping its entire index."""
    global finetuned_collection
    chroma_client.delete_collection("finetuned_rag")
    finetuned_collection = chroma_client.get_or_create_collection(
        name="finetuned_rag", metadata={"hnsw:space": "cosine"}
    )
    return "Fine-tuned database reset."
| |
|
| | |
def list_baseline_files():
    """Summarize the baseline index as ``[file_name, chunk_count, url]`` rows,
    sorted by chunk count descending. Returns a placeholder row when empty."""
    total = baseline_collection.count()
    if total == 0:
        return [["No data indexed yet", "-", "-"]]

    try:
        # Cap the fetch at 1000 records to keep the inspector responsive.
        records = baseline_collection.get(limit=min(total, 1000), include=["metadatas"])
        stats = {}
        for meta in records["metadatas"]:
            name = meta.get("file_name", "unknown")
            # First occurrence of a file decides the URL shown for it.
            entry = stats.setdefault(name, {"count": 0, "url": meta.get("url", "unknown")})
            entry["count"] += 1

        rows = [[name, info["count"], info["url"]] for name, info in stats.items()]
        rows.sort(key=lambda row: row[1], reverse=True)
        return rows
    except Exception as e:
        return [[f"Error: {str(e)}", "-", "-"]]
| |
|
def list_finetuned_files():
    """Summarize the fine-tuned index as ``[file_name, chunk_count, url]`` rows,
    sorted by chunk count descending. Returns a placeholder row when empty."""
    total = finetuned_collection.count()
    if total == 0:
        return [["No data indexed yet", "-", "-"]]

    try:
        # Cap the fetch at 1000 records to keep the inspector responsive.
        records = finetuned_collection.get(limit=min(total, 1000), include=["metadatas"])
        stats = {}
        for meta in records["metadatas"]:
            name = meta.get("file_name", "unknown")
            # First occurrence of a file decides the URL shown for it.
            entry = stats.setdefault(name, {"count": 0, "url": meta.get("url", "unknown")})
            entry["count"] += 1

        rows = [[name, info["count"], info["url"]] for name, info in stats.items()]
        rows.sort(key=lambda row: row[1], reverse=True)
        return rows
    except Exception as e:
        return [[f"Error: {str(e)}", "-", "-"]]
| |
|
| | |
def get_files_list_baseline():
    """Get list of unique files in baseline collection."""
    try:
        metas = baseline_collection.get(include=["metadatas"])["metadatas"]
        if not metas:
            return []
        # Deduplicate via a set comprehension, then present in sorted order.
        return sorted({m.get("file_name", "unknown") for m in metas})
    except Exception:
        # Dropdown refresh is best-effort; an empty list is a safe fallback.
        return []
| |
|
def get_files_list_finetuned():
    """Get list of unique files in fine-tuned collection."""
    try:
        metas = finetuned_collection.get(include=["metadatas"])["metadatas"]
        if not metas:
            return []
        # Deduplicate via a set comprehension, then present in sorted order.
        return sorted({m.get("file_name", "unknown") for m in metas})
    except Exception:
        # Dropdown refresh is best-effort; an empty list is a safe fallback.
        return []
| |
|
def get_chunks_for_file_baseline(file_name):
    """Collect every baseline-collection chunk belonging to *file_name*.

    Returns a dict with the file name, chunk count and per-chunk details
    (truncated preview, length, metadata, embedding dimension), or an
    ``{"error": ...}`` dict when nothing matches or the lookup fails.
    """
    if not file_name:
        return {"error": "No file selected"}

    try:
        data = baseline_collection.get(
            include=["documents", "metadatas", "embeddings"]
        )
        if not data["documents"]:
            return {"error": "No chunks found"}

        matching = []
        for doc, meta, emb in zip(data["documents"], data["metadatas"], data["embeddings"]):
            if meta.get("file_name") != file_name:
                continue
            # Preview caps at 500 chars; ellipsis only when truncated.
            preview = doc[:500] + "..." if len(doc) > 500 else doc
            matching.append({
                "chunk_id": len(matching) + 1,
                "content": preview,
                "full_length": len(doc),
                "metadata": meta,
                "embedding_dim": len(emb) if emb is not None else 0,
            })

        if not matching:
            return {"error": "No chunks found for this file"}

        return {
            "file_name": file_name,
            "total_chunks": len(matching),
            "chunks": matching,
        }
    except Exception as e:
        import traceback
        # Full traceback to the server log; short message to the UI.
        print(f"ERROR in get_chunks_for_file_baseline: {traceback.format_exc()}")
        return {"error": str(e)}
| |
|
def get_chunks_for_file_finetuned(file_name):
    """Collect every fine-tuned-collection chunk belonging to *file_name*.

    Returns a dict with the file name, chunk count and per-chunk details
    (truncated preview, length, metadata, embedding dimension), or an
    ``{"error": ...}`` dict when nothing matches or the lookup fails.
    """
    if not file_name:
        return {"error": "No file selected"}

    try:
        data = finetuned_collection.get(
            include=["documents", "metadatas", "embeddings"]
        )
        if not data["documents"]:
            return {"error": "No chunks found"}

        matching = []
        for doc, meta, emb in zip(data["documents"], data["metadatas"], data["embeddings"]):
            if meta.get("file_name") != file_name:
                continue
            # Preview caps at 500 chars; ellipsis only when truncated.
            preview = doc[:500] + "..." if len(doc) > 500 else doc
            matching.append({
                "chunk_id": len(matching) + 1,
                "content": preview,
                "full_length": len(doc),
                "metadata": meta,
                "embedding_dim": len(emb) if emb is not None else 0,
            })

        if not matching:
            return {"error": "No chunks found for this file"}

        return {
            "file_name": file_name,
            "total_chunks": len(matching),
            "chunks": matching,
        }
    except Exception as e:
        return {"error": str(e)}
| |
|
def download_chunks_baseline(file_name):
    """Export the baseline chunk report for *file_name* to a JSON file.

    Returns the path of a temporary ``.json`` file (served by the Gradio
    File component), or ``None`` when no file is selected.
    """
    if not file_name:
        return None

    import json

    chunks_data = get_chunks_for_file_baseline(file_name)

    # Context manager guarantees the handle is closed even if serialization
    # raises; delete=False keeps the file on disk so Gradio can serve it.
    # (tempfile is already imported at module level; the old local re-import
    # was redundant.)
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".json", encoding="utf-8"
    ) as temp_file:
        json.dump(chunks_data, temp_file, indent=2)
        return temp_file.name
| |
|
def download_chunks_finetuned(file_name):
    """Export the fine-tuned chunk report for *file_name* to a JSON file.

    Returns the path of a temporary ``.json`` file (served by the Gradio
    File component), or ``None`` when no file is selected.
    """
    if not file_name:
        return None

    import json

    chunks_data = get_chunks_for_file_finetuned(file_name)

    # Context manager guarantees the handle is closed even if serialization
    # raises; delete=False keeps the file on disk so Gradio can serve it.
    # (tempfile is already imported at module level; the old local re-import
    # was redundant.)
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".json", encoding="utf-8"
    ) as temp_file:
        json.dump(chunks_data, temp_file, indent=2)
        return temp_file.name
| |
|
| | |
def search_baseline(query, top_k=5):
    """Semantic search over the baseline collection.

    Returns up to *top_k* rows of ``[file_name, similarity, snippet]``.
    An empty list is returned when the index is empty or the query could
    not be embedded.
    """
    if baseline_collection.count() == 0:
        return []
    query_emb = compute_baseline_embeddings([query])
    if query_emb is None:
        return []
    query_vec = query_emb.cpu().numpy().tolist()[0]
    results = baseline_collection.query(
        query_embeddings=[query_vec],
        n_results=min(top_k, baseline_collection.count()),
        include=["metadatas", "documents", "distances"],
    )
    output = []
    if results["ids"]:
        for i in range(len(results["ids"][0])):
            meta = results["metadatas"][0][i]
            code = results["documents"][0][i]
            dist = results["distances"][0][i]
            score = 1 - dist  # cosine distance -> similarity
            # FIX: only append an ellipsis when the snippet was actually
            # truncated (previously "..." was added unconditionally, which was
            # misleading and inconsistent with the chunk inspector).
            snippet = code[:300] + "..." if len(code) > 300 else code
            output.append([meta.get("file_name", "unknown"), f"{score:.4f}", snippet])
    return output
| |
|
def search_finetuned(query, top_k=5):
    """Semantic search over the fine-tuned collection.

    Returns up to *top_k* rows of ``[file_name, similarity, snippet]``.
    An empty list is returned when the index is empty or the query could
    not be embedded.
    """
    if finetuned_collection.count() == 0:
        return []
    query_emb = compute_finetuned_embeddings([query])
    if query_emb is None:
        return []
    query_vec = query_emb.cpu().numpy().tolist()[0]
    results = finetuned_collection.query(
        query_embeddings=[query_vec],
        n_results=min(top_k, finetuned_collection.count()),
        include=["metadatas", "documents", "distances"],
    )
    output = []
    if results["ids"]:
        for i in range(len(results["ids"][0])):
            meta = results["metadatas"][0][i]
            code = results["documents"][0][i]
            dist = results["distances"][0][i]
            score = 1 - dist  # cosine distance -> similarity
            # FIX: only append an ellipsis when the snippet was actually
            # truncated (previously "..." was added unconditionally, which was
            # misleading and inconsistent with the chunk inspector).
            snippet = code[:300] + "..." if len(code) > 300 else code
            output.append([meta.get("file_name", "unknown"), f"{score:.4f}", snippet])
    return output
| |
|
def search_comparison(query, top_k=5):
    """Run the same query against both indexes.

    Returns ``(baseline_rows, finetuned_rows)`` for the two result tables.
    """
    return search_baseline(query, top_k), search_finetuned(query, top_k)
| |
|
| | |
def ingest_from_url(repo_url):
    """Clone a Git repository, chunk its source files, and index every chunk
    in BOTH Chroma collections (baseline and fine-tuned embeddings).

    Generator: yields human-readable status strings that Gradio streams into
    the Status textbox.
    """
    if not repo_url.startswith("http"):
        yield "Invalid URL"
        return

    DATA_DIR = Path(os.path.abspath("data/raw_ingest"))
    import stat
    def remove_readonly(func, path, _):
        # shutil.rmtree error handler: .git objects can be read-only (notably
        # on Windows); make the path writable and retry the failed operation.
        os.chmod(path, stat.S_IWRITE)
        func(path)

    try:
        # Start from a clean checkout directory on every ingest.
        if DATA_DIR.exists():
            shutil.rmtree(DATA_DIR, onerror=remove_readonly)

        yield f"Cloning {repo_url}..."
        crawler = GitCrawler(cache_dir=DATA_DIR)
        repo_path = crawler.clone_repository(repo_url)
        if not repo_path:
            yield "Failed to clone repository."
            return

        yield "Listing files..."
        files = crawler.list_files(repo_path, extensions={'.py', '.md', '.json', '.js', '.ts', '.java', '.cpp'})
        # NOTE(review): some list_files variants apparently return a tuple
        # whose first element holds objects with a .path attribute -- confirm
        # against the GitCrawler implementation.
        if isinstance(files, tuple): files = [f.path for f in files[0]]

        total_files = len(files)
        yield f"Found {total_files} files. Chunking..."

        chunker = RepoChunker()
        all_chunks = []

        for i, file_path in enumerate(files):
            yield f"Chunking: {i+1}/{total_files} ({file_path.name})"
            try:
                meta = {"file_name": file_path.name, "url": repo_url}
                file_chunks = chunker.chunk_file(file_path, repo_metadata=meta)
                all_chunks.extend(file_chunks)
            except Exception as e:
                # Best-effort: files the chunker can't handle are skipped.
                print(f"Skipping {file_path}: {e}")

        if not all_chunks:
            yield "No valid chunks found."
            return

        total_chunks = len(all_chunks)
        yield f"Generated {total_chunks} chunks. Embedding (BASELINE)..."

        batch_size = 64  # bounds per-batch embedding memory

        for i in range(0, total_chunks, batch_size):
            batch = all_chunks[i:i+batch_size]
            texts = [c.code for c in batch]
            ids = [str(uuid.uuid4()) for _ in batch]
            metadatas = [{"file_name": Path(c.file_path).name, "url": repo_url} for c in batch]

            embeddings = compute_baseline_embeddings(texts)
            if embeddings is not None:
                baseline_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
            yield f"Baseline: {min(i+batch_size, total_chunks)}/{total_chunks}"

        yield f"Embedding (FINE-TUNED)..."

        # Second pass: same chunks, fine-tuned encoder, separate collection
        # (fresh UUIDs, so the two collections share no ids).
        for i in range(0, total_chunks, batch_size):
            batch = all_chunks[i:i+batch_size]
            texts = [c.code for c in batch]
            ids = [str(uuid.uuid4()) for _ in batch]
            metadatas = [{"file_name": Path(c.file_path).name, "url": repo_url} for c in batch]

            embeddings = compute_finetuned_embeddings(texts)
            if embeddings is not None:
                finetuned_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
            yield f"Fine-tuned: {min(i+batch_size, total_chunks)}/{total_chunks}"

        yield f"SUCCESS! Indexed {total_chunks} chunks in both databases."
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {str(e)}"
| |
|
def ingest_from_files(files):
    """Chunk user-uploaded files and index the chunks in BOTH collections.

    Generator: yields status strings for the Gradio Status textbox.
    *files* is the list of Gradio file objects; each exposes a ``.name``
    attribute holding the temp-file path of the upload.
    """
    if not files or len(files) == 0:
        yield "No files uploaded."
        return

    try:
        yield f"Processing {len(files)} file(s)..."

        chunker = RepoChunker()
        all_chunks = []

        for i, file in enumerate(files):
            yield f"Chunking file {i+1}/{len(files)}: {Path(file.name).name}"
            try:
                # Gradio stores uploads as temp files; .name is the temp path.
                file_path = Path(file.name)
                meta = {"file_name": file_path.name, "url": "uploaded"}
                file_chunks = chunker.chunk_file(file_path, repo_metadata=meta)
                all_chunks.extend(file_chunks)
            except Exception as e:
                # Report per-file failures but keep processing the rest.
                yield f"Error chunking {Path(file.name).name}: {str(e)}"
                import traceback
                traceback.print_exc()

        if not all_chunks:
            yield "No valid chunks found."
            return

        total_chunks = len(all_chunks)
        yield f"Generated {total_chunks} chunks. Embedding (BASELINE)..."

        batch_size = 64  # bounds per-batch embedding memory
        for i in range(0, total_chunks, batch_size):
            batch = all_chunks[i:i+batch_size]
            texts = [c.code for c in batch]
            ids = [str(uuid.uuid4()) for _ in batch]
            metadatas = [{"file_name": Path(c.file_path).name, "url": "uploaded"} for c in batch]

            embeddings = compute_baseline_embeddings(texts)
            if embeddings is not None:
                baseline_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
            yield f"Baseline: {min(i+batch_size, total_chunks)}/{total_chunks}"

        yield f"Embedding (FINE-TUNED)..."
        # Second pass: same chunks, fine-tuned encoder, separate collection.
        for i in range(0, total_chunks, batch_size):
            batch = all_chunks[i:i+batch_size]
            texts = [c.code for c in batch]
            ids = [str(uuid.uuid4()) for _ in batch]
            metadatas = [{"file_name": Path(c.file_path).name, "url": "uploaded"} for c in batch]

            embeddings = compute_finetuned_embeddings(texts)
            if embeddings is not None:
                finetuned_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
            yield f"Fine-tuned: {min(i+batch_size, total_chunks)}/{total_chunks}"

        yield f"SUCCESS! Indexed {total_chunks} chunks from uploaded files."
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {str(e)}"
| |
|
| | |
def analyze_embeddings_baseline():
    """Diagnostics for the baseline collection: a 2-D PCA scatter of the
    stored embeddings plus a diversity score (1 - mean pairwise similarity).

    Returns ``(metrics_text, PIL.Image)``, or ``(message, None)`` when there
    is too little data or the analysis fails.
    """
    count = baseline_collection.count()
    if count < 5:
        return "Not enough data (Need > 5 chunks).", None

    try:
        # Cap the fetch so huge indexes don't blow up memory.
        limit = min(count, 2000)
        data = baseline_collection.get(limit=limit, include=["embeddings", "metadatas"])

        # PCA via torch.pca_lowrank on mean-centered embeddings (top 2 PCs).
        X = torch.tensor(data['embeddings'])
        X_mean = torch.mean(X, 0)
        X_centered = X - X_mean
        U, S, V = torch.pca_lowrank(X_centered, q=2)
        projected = torch.matmul(X_centered, V[:, :2]).numpy()

        # Average pairwise similarity over a random sample of <= 100 vectors.
        # Embeddings were L2-normalized at ingest (compute_*_embeddings), so
        # the raw dot product equals cosine similarity.
        indices = torch.randint(0, len(X), (min(100, len(X)),))
        sample = X[indices]
        sim_matrix = torch.mm(sample, sample.t())
        mask = ~torch.eye(len(sample), dtype=bool)  # drop the self-similarity diagonal
        avg_sim = sim_matrix[mask].mean().item()
        diversity_score = 1.0 - avg_sim

        metrics = (
            f"BASELINE MODEL\n"
            f"Total Chunks: {count}\n"
            f"Analyzed: {len(X)}\n"
            f"Diversity Score: {diversity_score:.4f}\n"
            f"Avg Similarity: {avg_sim:.4f}"
        )

        # One scatter series per source file so the legend doubles as a key.
        plot_df = pd.DataFrame({
            "x": projected[:, 0],
            "y": projected[:, 1],
            "topic": [m.get("file_name", "unknown") for m in data['metadatas']]
        })

        import matplotlib.pyplot as plt
        import io
        from PIL import Image

        fig, ax = plt.subplots(figsize=(10, 8))
        fig.subplots_adjust(top=0.92)

        unique_topics = plot_df["topic"].unique()
        for topic in unique_topics:
            mask = plot_df["topic"] == topic
            ax.scatter(plot_df[mask]["x"], plot_df[mask]["y"], label=topic, alpha=0.6, s=50)

        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        ax.set_title("Baseline Semantic Space (PCA)", fontsize=14, pad=20)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
        ax.grid(True, alpha=0.3)
        plt.tight_layout()

        # Render to an in-memory PNG and hand Gradio a PIL image.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)
        plt.close()

        return metrics, img
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {e}", None
| |
|
def analyze_embeddings_finetuned():
    """Diagnostics for the fine-tuned collection: a 2-D PCA scatter of the
    stored embeddings plus a diversity score (1 - mean pairwise similarity).

    Returns ``(metrics_text, PIL.Image)``, or ``(message, None)`` when there
    is too little data or the analysis fails.
    """
    count = finetuned_collection.count()
    if count < 5:
        return "Not enough data (Need > 5 chunks).", None

    try:
        # Cap the fetch so huge indexes don't blow up memory.
        limit = min(count, 2000)
        data = finetuned_collection.get(limit=limit, include=["embeddings", "metadatas"])

        # PCA via torch.pca_lowrank on mean-centered embeddings (top 2 PCs).
        X = torch.tensor(data['embeddings'])
        X_mean = torch.mean(X, 0)
        X_centered = X - X_mean
        U, S, V = torch.pca_lowrank(X_centered, q=2)
        projected = torch.matmul(X_centered, V[:, :2]).numpy()

        # Average pairwise similarity over a random sample of <= 100 vectors.
        # Embeddings were L2-normalized at ingest (compute_*_embeddings), so
        # the raw dot product equals cosine similarity.
        indices = torch.randint(0, len(X), (min(100, len(X)),))
        sample = X[indices]
        sim_matrix = torch.mm(sample, sample.t())
        mask = ~torch.eye(len(sample), dtype=bool)  # drop the self-similarity diagonal
        avg_sim = sim_matrix[mask].mean().item()
        diversity_score = 1.0 - avg_sim

        metrics = (
            f"FINE-TUNED MODEL\n"
            f"Total Chunks: {count}\n"
            f"Analyzed: {len(X)}\n"
            f"Diversity Score: {diversity_score:.4f}\n"
            f"Avg Similarity: {avg_sim:.4f}"
        )

        # One scatter series per source file so the legend doubles as a key.
        plot_df = pd.DataFrame({
            "x": projected[:, 0],
            "y": projected[:, 1],
            "topic": [m.get("file_name", "unknown") for m in data['metadatas']]
        })

        import matplotlib.pyplot as plt
        import io
        from PIL import Image

        fig, ax = plt.subplots(figsize=(10, 8))
        fig.subplots_adjust(top=0.92)

        unique_topics = plot_df["topic"].unique()
        for topic in unique_topics:
            mask = plot_df["topic"] == topic
            ax.scatter(plot_df[mask]["x"], plot_df[mask]["y"], label=topic, alpha=0.6, s=50)

        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        ax.set_title("Fine-tuned Semantic Space (PCA)", fontsize=14, pad=20)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
        ax.grid(True, alpha=0.3)
        plt.tight_layout()

        # Render to an in-memory PNG and hand Gradio a PIL image.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)
        plt.close()

        return metrics, img
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {e}", None
| |
|
def evaluate_retrieval_baseline(sample_limit):
    """Self-retrieval benchmark for the baseline model (generator).

    For each sampled chunk, the first 3 lines of its code become a synthetic
    query and the chunk itself is the single relevant document. Yields
    progress strings, then a final Recall@1 / Recall@5 / MRR report.
    """
    count = baseline_collection.count()
    if count < 10:
        # BUG FIX: this function is a generator, so the old `return "<msg>"`
        # stashed the message in StopIteration and Gradio never displayed it.
        # Yield the message instead, then stop.
        yield "Not enough data for evaluation (Need > 10 chunks)."
        return

    try:
        fetch_limit = min(count, 2000)  # cap memory for very large indexes
        data = baseline_collection.get(limit=fetch_limit, include=["documents"])

        import random
        actual_sample_size = min(sample_limit, len(data['ids']))
        sample_indices = random.sample(range(len(data['ids'])), actual_sample_size)

        hits_at_1 = 0
        hits_at_5 = 0
        mrr_sum = 0

        yield f"BASELINE: Evaluating {actual_sample_size} chunks..."

        for i, idx in enumerate(sample_indices):
            target_id = data['ids'][idx]
            code = data['documents'][idx]
            # Synthetic query: the chunk's first three lines.
            query = "\n".join(code.split("\n")[:3])
            query_emb = compute_baseline_embeddings([query]).cpu().numpy().tolist()[0]
            results = baseline_collection.query(query_embeddings=[query_emb], n_results=10)
            found_ids = results['ids'][0]
            if target_id in found_ids:
                rank = found_ids.index(target_id) + 1
                mrr_sum += 1.0 / rank
                if rank == 1:
                    hits_at_1 += 1
                if rank <= 5:
                    hits_at_5 += 1
            if i % 10 == 0:
                yield f"Baseline: {i}/{actual_sample_size}..."

        recall_1 = hits_at_1 / actual_sample_size
        recall_5 = hits_at_5 / actual_sample_size
        mrr = mrr_sum / actual_sample_size

        report = (
            f"BASELINE EVALUATION ({actual_sample_size} chunks)\n"
            f"{'='*40}\n"
            f"Recall@1: {recall_1:.4f}\n"
            f"Recall@5: {recall_5:.4f}\n"
            f"MRR: {mrr:.4f}"
        )
        yield report
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {e}"
| |
|
def evaluate_retrieval_finetuned(sample_limit):
    """Self-retrieval benchmark for the fine-tuned model (generator).

    For each sampled chunk, the first 3 lines of its code become a synthetic
    query and the chunk itself is the single relevant document. Yields
    progress strings, then a final Recall@1 / Recall@5 / MRR report.
    """
    count = finetuned_collection.count()
    if count < 10:
        # BUG FIX: this function is a generator, so the old `return "<msg>"`
        # stashed the message in StopIteration and Gradio never displayed it.
        # Yield the message instead, then stop.
        yield "Not enough data for evaluation (Need > 10 chunks)."
        return

    try:
        fetch_limit = min(count, 2000)  # cap memory for very large indexes
        data = finetuned_collection.get(limit=fetch_limit, include=["documents"])

        import random
        actual_sample_size = min(sample_limit, len(data['ids']))
        sample_indices = random.sample(range(len(data['ids'])), actual_sample_size)

        hits_at_1 = 0
        hits_at_5 = 0
        mrr_sum = 0

        yield f"FINE-TUNED: Evaluating {actual_sample_size} chunks..."

        for i, idx in enumerate(sample_indices):
            target_id = data['ids'][idx]
            code = data['documents'][idx]
            # Synthetic query: the chunk's first three lines.
            query = "\n".join(code.split("\n")[:3])
            query_emb = compute_finetuned_embeddings([query]).cpu().numpy().tolist()[0]
            results = finetuned_collection.query(query_embeddings=[query_emb], n_results=10)
            found_ids = results['ids'][0]
            if target_id in found_ids:
                rank = found_ids.index(target_id) + 1
                mrr_sum += 1.0 / rank
                if rank == 1:
                    hits_at_1 += 1
                if rank <= 5:
                    hits_at_5 += 1
            if i % 10 == 0:
                yield f"Fine-tuned: {i}/{actual_sample_size}..."

        recall_1 = hits_at_1 / actual_sample_size
        recall_5 = hits_at_5 / actual_sample_size
        mrr = mrr_sum / actual_sample_size

        report = (
            f"FINE-TUNED EVALUATION ({actual_sample_size} chunks)\n"
            f"{'='*40}\n"
            f"Recall@1: {recall_1:.4f}\n"
            f"Recall@5: {recall_5:.4f}\n"
            f"MRR: {mrr:.4f}"
        )
        yield report
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {e}"
| |
|
| | |
# Gradio theme: muted slate palette, compact spacing, light card-style blocks.
theme = gr.themes.Soft(primary_hue="slate", neutral_hue="slate", spacing_size="sm", radius_size="md").set(body_background_fill="*neutral_50", block_background_fill="white", block_border_width="1px", block_title_text_weight="600")

# Extra CSS: centered title, capped page width, styled comparison headers.
css = """
h1 { text-align: center; font-family: 'Inter', sans-serif; margin-bottom: 1rem; color: #1e293b; }
.gradio-container { max-width: 1400px !important; margin: auto; }
.comparison-header { font-size: 1.1rem; font-weight: 600; color: #334155; text-align: center; padding: 0.5rem; }
"""
| |
|
# ----------------------------- Gradio UI layout -------------------------------
with gr.Blocks(theme=theme, css=css, title="CodeMode - Baseline vs Fine-tuned") as demo:
    gr.Markdown("# CodeMode: Baseline vs Fine-tuned Model Comparison")
    gr.Markdown("Compare retrieval performance between **microsoft/codebert-base** (baseline) and **MRL-enhanced fine-tuned** model")

    with gr.Tabs():

        # Tab 1: ingestion (GitHub URL or file upload) plus DB/chunk inspectors.
        with gr.Tab("1. Ingest Code"):
            with gr.Tabs():
                with gr.Tab("GitHub Repository"):
                    repo_input = gr.Textbox(label="GitHub URL", placeholder="https://github.com/pallets/flask")
                    ingest_url_btn = gr.Button("Ingest from URL", variant="primary")
                    url_status = gr.Textbox(label="Status")
                    # ingest_from_url is a generator -> progress streams into Status.
                    ingest_url_btn.click(ingest_from_url, inputs=repo_input, outputs=url_status)

                with gr.Tab("Upload Python Files"):
                    file_upload = gr.File(label="Upload .py files", file_types=[".py"], file_count="multiple")
                    ingest_files_btn = gr.Button("Ingest Uploaded Files", variant="primary")
                    upload_status = gr.Textbox(label="Status")
                    ingest_files_btn.click(ingest_from_files, inputs=file_upload, outputs=upload_status)

            # Destructive reset controls for each collection.
            with gr.Row():
                reset_baseline_btn = gr.Button("Reset Baseline DB", variant="stop")
                reset_finetuned_btn = gr.Button("Reset Fine-tuned DB", variant="stop")
                reset_status = gr.Textbox(label="Reset Status")

            reset_baseline_btn.click(reset_baseline, inputs=[], outputs=reset_status)
            reset_finetuned_btn.click(reset_finetuned, inputs=[], outputs=reset_status)

            gr.Markdown("---")
            gr.Markdown("### Database Inspector")
            gr.Markdown("View indexed files in each collection")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Collection")
                    inspect_baseline_btn = gr.Button("Inspect Baseline DB", variant="secondary")
                    baseline_files_df = gr.Dataframe(
                        headers=["File Name", "Chunks", "Source URL"],
                        datatype=["str", "number", "str"],
                        interactive=False,
                        value=[["No data yet", "-", "-"]]
                    )
                    inspect_baseline_btn.click(list_baseline_files, inputs=[], outputs=baseline_files_df)

                with gr.Column():
                    gr.Markdown("#### Fine-tuned Collection")
                    inspect_finetuned_btn = gr.Button("Inspect Fine-tuned DB", variant="secondary")
                    finetuned_files_df = gr.Dataframe(
                        headers=["File Name", "Chunks", "Source URL"],
                        datatype=["str", "number", "str"],
                        interactive=False,
                        value=[["No data yet", "-", "-"]]
                    )
                    inspect_finetuned_btn.click(list_finetuned_files, inputs=[], outputs=finetuned_files_df)

            gr.Markdown("---")
            gr.Markdown("### Chunk Inspector")
            gr.Markdown("View detailed chunk information for indexed files (content, metadata, schema)")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Collection")
                    baseline_file_dropdown = gr.Dropdown(
                        label="Select File to Inspect",
                        choices=[],
                        interactive=True
                    )
                    baseline_refresh_files = gr.Button("Refresh File List", variant="secondary")
                    baseline_chunks_display = gr.JSON(label="Chunk Details")
                    baseline_download_btn = gr.Button("Download Chunks as JSON", variant="primary")
                    baseline_download_output = gr.File(label="Download")

                with gr.Column():
                    gr.Markdown("#### Fine-tuned Collection")
                    finetuned_file_dropdown = gr.Dropdown(
                        label="Select File to Inspect",
                        choices=[],
                        interactive=True
                    )
                    finetuned_refresh_files = gr.Button("Refresh File List", variant="secondary")
                    finetuned_chunks_display = gr.JSON(label="Chunk Details")
                    finetuned_download_btn = gr.Button("Download Chunks as JSON", variant="primary")
                    finetuned_download_output = gr.File(label="Download")

            # Chunk-inspector wiring. The lambdas return a fresh gr.Dropdown so
            # the choices list is replaced in place on refresh.
            baseline_refresh_files.click(
                lambda: gr.Dropdown(choices=get_files_list_baseline()),
                outputs=baseline_file_dropdown
            )
            baseline_file_dropdown.change(
                get_chunks_for_file_baseline,
                inputs=baseline_file_dropdown,
                outputs=baseline_chunks_display
            )
            baseline_download_btn.click(
                download_chunks_baseline,
                inputs=baseline_file_dropdown,
                outputs=baseline_download_output
            )

            finetuned_refresh_files.click(
                lambda: gr.Dropdown(choices=get_files_list_finetuned()),
                outputs=finetuned_file_dropdown
            )
            finetuned_file_dropdown.change(
                get_chunks_for_file_finetuned,
                inputs=finetuned_file_dropdown,
                outputs=finetuned_chunks_display
            )
            finetuned_download_btn.click(
                download_chunks_finetuned,
                inputs=finetuned_file_dropdown,
                outputs=finetuned_download_output
            )

        # Tab 2: one text query, both indexes, results side by side.
        with gr.Tab("2. Comparison Search (Note: Semantic search is sensitive to query phrasing)"):
            gr.Markdown("### Side-by-Side Retrieval Comparison")
            search_query = gr.Textbox(label="Search Query", placeholder="e.g., 'Flask route decorator'")
            compare_btn = gr.Button("Compare Models", variant="primary")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("<div class='comparison-header'>BASELINE (CodeBERT)</div>", elem_classes="comparison-header")
                    baseline_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)

                with gr.Column():
                    gr.Markdown("<div class='comparison-header'>FINE-TUNED (MRL-Enhanced)</div>", elem_classes="comparison-header")
                    finetuned_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)

            compare_btn.click(search_comparison, inputs=search_query, outputs=[baseline_results, finetuned_results])

        # Tab 3: code-to-code similarity search with model selection.
        with gr.Tab("3. Code Similarity Search"):
            gr.Markdown("### Find Similar Code Snippets")
            gr.Markdown("Paste a code snippet to find similar code in the database")

            with gr.Row():
                with gr.Column():
                    code_input = gr.Code(label="Paste Code Snippet", language="python", lines=10)
                    similarity_btn = gr.Button("Find Similar Code", variant="primary")

                with gr.Column():
                    gr.Markdown("#### Search Settings")
                    top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of Results")
                    model_choice = gr.Radio(["Baseline", "Fine-tuned", "Both"], value="Both", label="Model to Use")

            gr.Markdown("### Results")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Results")
                    baseline_code_results = gr.Dataframe(
                        headers=["File", "Similarity", "Code Snippet"],
                        datatype=["str", "str", "str"],
                        interactive=False,
                        wrap=True,
                        value=[["No search yet", "-", "-"]]
                    )

                with gr.Column():
                    gr.Markdown("#### Fine-tuned Results")
                    finetuned_code_results = gr.Dataframe(
                        headers=["File", "Similarity", "Code Snippet"],
                        datatype=["str", "str", "str"],
                        interactive=False,
                        wrap=True,
                        value=[["No search yet", "-", "-"]]
                    )

            def search_similar_code(code_snippet, top_k, model_choice):
                """Dispatch the snippet to one or both search backends and
                return (baseline_rows, finetuned_rows) with placeholder rows
                for empty/unsearched cases."""
                if not code_snippet or len(code_snippet.strip()) == 0:
                    empty = [["Enter code to search", "-", "-"]]
                    return empty, empty

                baseline_res = []
                finetuned_res = []

                if model_choice in ["Baseline", "Both"]:
                    baseline_res = search_baseline(code_snippet, top_k)
                    if not baseline_res:
                        baseline_res = [["No results found", "-", "-"]]

                if model_choice in ["Fine-tuned", "Both"]:
                    finetuned_res = search_finetuned(code_snippet, top_k)
                    if not finetuned_res:
                        finetuned_res = [["No results found", "-", "-"]]

                # Mark the table that was deliberately skipped.
                if model_choice == "Baseline":
                    finetuned_res = [["Not searched", "-", "-"]]
                elif model_choice == "Fine-tuned":
                    baseline_res = [["Not searched", "-", "-"]]

                return baseline_res, finetuned_res

            similarity_btn.click(
                search_similar_code,
                inputs=[code_input, top_k_slider, model_choice],
                outputs=[baseline_code_results, finetuned_code_results]
            )

        # Tab 4: embedding-space diagnostics and retrieval benchmarks.
        with gr.Tab("4. Deployment Monitoring"):
            gr.Markdown("### Embedding Quality Analysis")
            gr.Markdown("Analyze the semantic space distribution and diversity of embeddings")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Model")
                    analyze_baseline_btn = gr.Button("Analyze Baseline Embeddings", variant="secondary")
                    baseline_metrics = gr.Textbox(label="Baseline Metrics")
                    baseline_plot = gr.Image()
                    analyze_baseline_btn.click(analyze_embeddings_baseline, inputs=[], outputs=[baseline_metrics, baseline_plot])

                with gr.Column():
                    gr.Markdown("#### Fine-tuned Model")
                    analyze_finetuned_btn = gr.Button("Analyze Fine-tuned Embeddings", variant="secondary")
                    finetuned_metrics = gr.Textbox(label="Fine-tuned Metrics")
                    finetuned_plot = gr.Image()
                    analyze_finetuned_btn.click(analyze_embeddings_finetuned, inputs=[], outputs=[finetuned_metrics, finetuned_plot])

            gr.Markdown("---")
            gr.Markdown("### Retrieval Performance Evaluation")
            gr.Markdown("Evaluate retrieval accuracy using synthetic queries (query = first 3 lines of code)")

            eval_size = gr.Slider(minimum=10, maximum=500, value=50, step=10, label="Sample Size (Chunks to Evaluate)")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Evaluation")
                    eval_baseline_btn = gr.Button("Run Baseline Evaluation", variant="primary")
                    baseline_eval_output = gr.Textbox(label="Baseline Results")
                    eval_baseline_btn.click(evaluate_retrieval_baseline, inputs=[eval_size], outputs=baseline_eval_output)

                with gr.Column():
                    gr.Markdown("#### Fine-tuned Evaluation")
                    eval_finetuned_btn = gr.Button("Run Fine-tuned Evaluation", variant="primary")
                    finetuned_eval_output = gr.Textbox(label="Fine-tuned Results")
                    eval_finetuned_btn.click(evaluate_retrieval_finetuned, inputs=[eval_size], outputs=finetuned_eval_output)
| |
|
if __name__ == "__main__":
    # queue() enables the generator-based streaming status updates used above.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
| |
|