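"""CodeMode comparison Space.

Gradio app that indexes a GitHub repository (or uploaded files) into two
parallel ChromaDB collections -- one embedded with the baseline
microsoft/codebert-base model, one with the MRL fine-tuned variant -- and
compares the two on side-by-side retrieval, code similarity search, and
embedding-quality diagnostics.
"""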
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import sys
import os
import shutil
from pathlib import Path
import chromadb
import uuid
import tempfile
# --- Add scripts to path ---
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from scripts.core.ingestion.ingest import GitCrawler
from scripts.core.ingestion.chunk import RepoChunker
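# NOTE: the sys.path hack above assumes the repo root (the directory that
# contains `scripts/`) sits two levels up from this file.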
# --- Configuration ---
BASELINE_MODEL = "microsoft/codebert-base"
FINETUNED_MODEL = "shubharuidas/codebert-base-code-embed-mrl-langchain-langgraph"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DB_DIR = Path(os.path.abspath("data/chroma_db_comparison"))
DB_DIR.mkdir(parents=True, exist_ok=True)
print(f"Loading models on {DEVICE}...")
print("1. Loading baseline model...")
baseline_tokenizer = AutoTokenizer.from_pretrained(BASELINE_MODEL)
baseline_model = AutoModel.from_pretrained(BASELINE_MODEL)
baseline_model.to(DEVICE)
baseline_model.eval()
print("2. Loading fine-tuned model...")
finetuned_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL)
finetuned_model = AutoModel.from_pretrained(FINETUNED_MODEL)
finetuned_model.to(DEVICE)
finetuned_model.eval()
print("Both models loaded!")
# --- ChromaDB Setup ---
chroma_client = chromadb.PersistentClient(path=str(DB_DIR))
baseline_collection = chroma_client.get_or_create_collection(name="baseline_rag", metadata={"hnsw:space": "cosine"})
finetuned_collection = chroma_client.get_or_create_collection(name="finetuned_rag", metadata={"hnsw:space": "cosine"})
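# With hnsw:space="cosine", collection.query() returns cosine *distances*;
# the search functions below convert them back to similarities as 1 - distance.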
# --- Embedding Functions ---
def compute_baseline_embeddings(text_list):
    if not text_list:
        return None
    inputs = baseline_tokenizer(text_list, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        out = baseline_model(**inputs)
        emb = out.last_hidden_state.mean(dim=1)  # mean-pool token embeddings
    return F.normalize(emb, p=2, dim=1)  # L2-normalize so dot product == cosine similarity

def compute_finetuned_embeddings(text_list):
    if not text_list:
        return None
    inputs = finetuned_tokenizer(text_list, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        out = finetuned_model(**inputs)
        emb = out.last_hidden_state.mean(dim=1)  # mean-pool token embeddings
    return F.normalize(emb, p=2, dim=1)  # L2-normalize so dot product == cosine similarity
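# Sanity-check sketch (not executed here): both functions return a
# [batch, hidden_size] tensor of unit-length vectors, e.g.
#   compute_baseline_embeddings(["def add(a, b): return a + b"]).shape
#   -> torch.Size([1, 768]) for codebert-base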
# --- Reset Functions ---
def reset_baseline():
    global baseline_collection
    chroma_client.delete_collection("baseline_rag")
    baseline_collection = chroma_client.get_or_create_collection(name="baseline_rag", metadata={"hnsw:space": "cosine"})
    return "Baseline database reset."

def reset_finetuned():
    global finetuned_collection
    chroma_client.delete_collection("finetuned_rag")
    finetuned_collection = chroma_client.get_or_create_collection(name="finetuned_rag", metadata={"hnsw:space": "cosine"})
    return "Fine-tuned database reset."
# --- Database Inspector Functions ---
def list_baseline_files():
    count = baseline_collection.count()
    if count == 0:
        return [["No data indexed yet", "-", "-"]]
    try:
        data = baseline_collection.get(limit=min(count, 1000), include=["metadatas"])
        file_stats = {}
        for meta in data['metadatas']:
            fname = meta.get("file_name", "unknown")
            url = meta.get("url", "unknown")
            if fname not in file_stats:
                file_stats[fname] = {"count": 0, "url": url}
            file_stats[fname]["count"] += 1
        results = [[fname, stats["count"], stats["url"]] for fname, stats in file_stats.items()]
        return sorted(results, key=lambda x: x[1], reverse=True)
    except Exception as e:
        return [[f"Error: {str(e)}", "-", "-"]]

def list_finetuned_files():
    count = finetuned_collection.count()
    if count == 0:
        return [["No data indexed yet", "-", "-"]]
    try:
        data = finetuned_collection.get(limit=min(count, 1000), include=["metadatas"])
        file_stats = {}
        for meta in data['metadatas']:
            fname = meta.get("file_name", "unknown")
            url = meta.get("url", "unknown")
            if fname not in file_stats:
                file_stats[fname] = {"count": 0, "url": url}
            file_stats[fname]["count"] += 1
        results = [[fname, stats["count"], stats["url"]] for fname, stats in file_stats.items()]
        return sorted(results, key=lambda x: x[1], reverse=True)
    except Exception as e:
        return [[f"Error: {str(e)}", "-", "-"]]
# --- Chunk Inspector Functions ---
def get_files_list_baseline():
    """Get list of unique files in baseline collection"""
    try:
        data = baseline_collection.get(include=["metadatas"])
        if not data['metadatas']:
            return []
        files = list(set([m.get("file_name", "unknown") for m in data['metadatas']]))
        return sorted(files)
    except Exception:
        return []

def get_files_list_finetuned():
    """Get list of unique files in fine-tuned collection"""
    try:
        data = finetuned_collection.get(include=["metadatas"])
        if not data['metadatas']:
            return []
        files = list(set([m.get("file_name", "unknown") for m in data['metadatas']]))
        return sorted(files)
    except Exception:
        return []
def get_chunks_for_file_baseline(file_name):
    """Get all chunks for a specific file from baseline collection"""
    if not file_name:
        return {"error": "No file selected"}
    try:
        # Get all data and filter by file_name in Python
        data = baseline_collection.get(
            include=["documents", "metadatas", "embeddings"]
        )
        if not data['documents']:
            return {"error": "No chunks found"}
        chunks = []
        for doc, meta, emb in zip(data['documents'], data['metadatas'], data['embeddings']):
            if meta.get("file_name") == file_name:
                chunks.append({
                    "chunk_id": len(chunks) + 1,
                    "content": doc[:500] + "..." if len(doc) > 500 else doc,
                    "full_length": len(doc),
                    "metadata": meta,
                    "embedding_dim": len(emb) if emb is not None else 0
                })
        if not chunks:
            return {"error": "No chunks found for this file"}
        return {
            "file_name": file_name,
            "total_chunks": len(chunks),
            "chunks": chunks
        }
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"ERROR in get_chunks_for_file_baseline: {error_details}")
        return {"error": str(e)}
def get_chunks_for_file_finetuned(file_name):
    """Get all chunks for a specific file from fine-tuned collection"""
    if not file_name:
        return {"error": "No file selected"}
    try:
        # Get all data and filter by file_name in Python
        data = finetuned_collection.get(
            include=["documents", "metadatas", "embeddings"]
        )
        if not data['documents']:
            return {"error": "No chunks found"}
        chunks = []
        for doc, meta, emb in zip(data['documents'], data['metadatas'], data['embeddings']):
            if meta.get("file_name") == file_name:
                chunks.append({
                    "chunk_id": len(chunks) + 1,
                    "content": doc[:500] + "..." if len(doc) > 500 else doc,
                    "full_length": len(doc),
                    "metadata": meta,
                    "embedding_dim": len(emb) if emb is not None else 0
                })
        if not chunks:
            return {"error": "No chunks found for this file"}
        return {
            "file_name": file_name,
            "total_chunks": len(chunks),
            "chunks": chunks
        }
    except Exception as e:
        return {"error": str(e)}
def download_chunks_baseline(file_name):
    """Export chunks to a JSON file for the baseline collection"""
    if not file_name:
        return None
    import json
    chunks_data = get_chunks_for_file_baseline(file_name)
    # delete=False keeps the file on disk so Gradio can serve it for download
    temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json')
    json.dump(chunks_data, temp_file, indent=2)
    temp_file.close()
    return temp_file.name

def download_chunks_finetuned(file_name):
    """Export chunks to a JSON file for the fine-tuned collection"""
    if not file_name:
        return None
    import json
    chunks_data = get_chunks_for_file_finetuned(file_name)
    temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json')
    json.dump(chunks_data, temp_file, indent=2)
    temp_file.close()
    return temp_file.name
# --- Search Functions ---
def search_baseline(query, top_k=5):
    if baseline_collection.count() == 0:
        return []
    query_emb = compute_baseline_embeddings([query])
    if query_emb is None:
        return []
    query_vec = query_emb.cpu().numpy().tolist()[0]
    # int() guards against float values arriving from the gr.Slider input
    n_results = min(int(top_k), baseline_collection.count())
    results = baseline_collection.query(query_embeddings=[query_vec], n_results=n_results, include=["metadatas", "documents", "distances"])
    output = []
    if results['ids']:
        for i in range(len(results['ids'][0])):
            meta = results['metadatas'][0][i]
            code = results['documents'][0][i]
            dist = results['distances'][0][i]
            score = 1 - dist  # cosine distance -> cosine similarity
            snippet = code[:300] + "..." if len(code) > 300 else code
            output.append([meta.get("file_name", "unknown"), f"{score:.4f}", snippet])
    return output

def search_finetuned(query, top_k=5):
    if finetuned_collection.count() == 0:
        return []
    query_emb = compute_finetuned_embeddings([query])
    if query_emb is None:
        return []
    query_vec = query_emb.cpu().numpy().tolist()[0]
    # int() guards against float values arriving from the gr.Slider input
    n_results = min(int(top_k), finetuned_collection.count())
    results = finetuned_collection.query(query_embeddings=[query_vec], n_results=n_results, include=["metadatas", "documents", "distances"])
    output = []
    if results['ids']:
        for i in range(len(results['ids'][0])):
            meta = results['metadatas'][0][i]
            code = results['documents'][0][i]
            dist = results['distances'][0][i]
            score = 1 - dist  # cosine distance -> cosine similarity
            snippet = code[:300] + "..." if len(code) > 300 else code
            output.append([meta.get("file_name", "unknown"), f"{score:.4f}", snippet])
    return output

def search_comparison(query, top_k=5):
    baseline_results = search_baseline(query, top_k)
    finetuned_results = search_finetuned(query, top_k)
    return baseline_results, finetuned_results
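# Usage sketch (not executed here):
#   base_rows, ft_rows = search_comparison("flask route decorator", top_k=5)
# Each row is [file_name, similarity, snippet], ready for a gr.Dataframe.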
# --- Ingestion Functions ---
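# Both ingestion functions below are generators: Gradio streams each yielded
# string to the status Textbox, giving live progress updates during indexing.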
def ingest_from_url(repo_url):
    if not repo_url.startswith("http"):
        yield "Invalid URL"
        return
    DATA_DIR = Path(os.path.abspath("data/raw_ingest"))
    import stat

    def remove_readonly(func, path, _):
        # Clear the read-only bit (e.g., on Windows .git files) and retry
        os.chmod(path, stat.S_IWRITE)
        func(path)
    try:
        if DATA_DIR.exists():
            shutil.rmtree(DATA_DIR, onerror=remove_readonly)
        yield f"Cloning {repo_url}..."
        crawler = GitCrawler(cache_dir=DATA_DIR)
        repo_path = crawler.clone_repository(repo_url)
        if not repo_path:
            yield "Failed to clone repository."
            return
        yield "Listing files..."
        files = crawler.list_files(repo_path, extensions={'.py', '.md', '.json', '.js', '.ts', '.java', '.cpp'})
        if isinstance(files, tuple):
            files = [f.path for f in files[0]]
        total_files = len(files)
        yield f"Found {total_files} files. Chunking..."
        chunker = RepoChunker()
        all_chunks = []
        for i, file_path in enumerate(files):
            yield f"Chunking: {i+1}/{total_files} ({file_path.name})"
            try:
                meta = {"file_name": file_path.name, "url": repo_url}
                file_chunks = chunker.chunk_file(file_path, repo_metadata=meta)
                all_chunks.extend(file_chunks)
            except Exception as e:
                print(f"Skipping {file_path}: {e}")
        if not all_chunks:
            yield "No valid chunks found."
            return
        total_chunks = len(all_chunks)
        yield f"Generated {total_chunks} chunks. Embedding (BASELINE)..."
        batch_size = 64
        # Index with baseline
        for i in range(0, total_chunks, batch_size):
            batch = all_chunks[i:i+batch_size]
            texts = [c.code for c in batch]
            ids = [str(uuid.uuid4()) for _ in batch]
            metadatas = [{"file_name": Path(c.file_path).name, "url": repo_url} for c in batch]
            embeddings = compute_baseline_embeddings(texts)
            if embeddings is not None:
                baseline_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
            yield f"Baseline: {min(i+batch_size, total_chunks)}/{total_chunks}"
        yield "Embedding (FINE-TUNED)..."
        # Index with fine-tuned
        for i in range(0, total_chunks, batch_size):
            batch = all_chunks[i:i+batch_size]
            texts = [c.code for c in batch]
            ids = [str(uuid.uuid4()) for _ in batch]
            metadatas = [{"file_name": Path(c.file_path).name, "url": repo_url} for c in batch]
            embeddings = compute_finetuned_embeddings(texts)
            if embeddings is not None:
                finetuned_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
            yield f"Fine-tuned: {min(i+batch_size, total_chunks)}/{total_chunks}"
        yield f"SUCCESS! Indexed {total_chunks} chunks in both databases."
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {str(e)}"
def ingest_from_files(files):
    if not files or len(files) == 0:
        yield "No files uploaded."
        return
    try:
        yield f"Processing {len(files)} file(s)..."
        chunker = RepoChunker()
        all_chunks = []
        for i, file in enumerate(files):
            yield f"Chunking file {i+1}/{len(files)}: {Path(file.name).name}"
            try:
                # Gradio file upload: file.name holds the temp path on disk
                file_path = Path(file.name)
                meta = {"file_name": file_path.name, "url": "uploaded"}
                file_chunks = chunker.chunk_file(file_path, repo_metadata=meta)
                all_chunks.extend(file_chunks)
            except Exception as e:
                yield f"Error chunking {Path(file.name).name}: {str(e)}"
                import traceback
                traceback.print_exc()
        if not all_chunks:
            yield "No valid chunks found."
            return
        total_chunks = len(all_chunks)
        yield f"Generated {total_chunks} chunks. Embedding (BASELINE)..."
        batch_size = 64
        for i in range(0, total_chunks, batch_size):
            batch = all_chunks[i:i+batch_size]
            texts = [c.code for c in batch]
            ids = [str(uuid.uuid4()) for _ in batch]
            metadatas = [{"file_name": Path(c.file_path).name, "url": "uploaded"} for c in batch]
            embeddings = compute_baseline_embeddings(texts)
            if embeddings is not None:
                baseline_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
            yield f"Baseline: {min(i+batch_size, total_chunks)}/{total_chunks}"
        yield "Embedding (FINE-TUNED)..."
        for i in range(0, total_chunks, batch_size):
            batch = all_chunks[i:i+batch_size]
            texts = [c.code for c in batch]
            ids = [str(uuid.uuid4()) for _ in batch]
            metadatas = [{"file_name": Path(c.file_path).name, "url": "uploaded"} for c in batch]
            embeddings = compute_finetuned_embeddings(texts)
            if embeddings is not None:
                finetuned_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
            yield f"Fine-tuned: {min(i+batch_size, total_chunks)}/{total_chunks}"
        yield f"SUCCESS! Indexed {total_chunks} chunks from uploaded files."
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {str(e)}"
# --- Analysis & Evaluation Functions ---
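# The two analyzers below are identical apart from the collection/model: they
# PCA-project embeddings to 2D (torch.pca_lowrank) for plotting, and report a
# diversity score = 1 - mean pairwise cosine similarity over a random sample
# (embeddings are unit-normalized, so dot products are cosine similarities).
# A lower average similarity means a more spread-out semantic space.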
def analyze_embeddings_baseline():
    count = baseline_collection.count()
    if count < 5:
        return "Not enough data (need at least 5 chunks).", None
    try:
        limit = min(count, 2000)
        data = baseline_collection.get(limit=limit, include=["embeddings", "metadatas"])
        X = torch.tensor(data['embeddings'])
        # Project to 2D with PCA for visualization
        X_mean = torch.mean(X, 0)
        X_centered = X - X_mean
        U, S, V = torch.pca_lowrank(X_centered, q=2)
        projected = torch.matmul(X_centered, V[:, :2]).numpy()
        # Diversity: 1 - average pairwise cosine similarity on a random sample
        indices = torch.randint(0, len(X), (min(100, len(X)),))
        sample = X[indices]
        sim_matrix = torch.mm(sample, sample.t())
        mask = ~torch.eye(len(sample), dtype=bool)
        avg_sim = sim_matrix[mask].mean().item()
        diversity_score = 1.0 - avg_sim
        metrics = (
            f"BASELINE MODEL\n"
            f"Total Chunks: {count}\n"
            f"Analyzed: {len(X)}\n"
            f"Diversity Score: {diversity_score:.4f}\n"
            f"Avg Similarity: {avg_sim:.4f}"
        )
        plot_df = pd.DataFrame({
            "x": projected[:, 0],
            "y": projected[:, 1],
            "topic": [m.get("file_name", "unknown") for m in data['metadatas']]
        })
        import matplotlib.pyplot as plt
        import io
        from PIL import Image
        # Create matplotlib figure with proper spacing
        fig, ax = plt.subplots(figsize=(10, 8))
        fig.subplots_adjust(top=0.92)  # add space for the title
        # Plot each file with a different color
        unique_topics = plot_df["topic"].unique()
        for topic in unique_topics:
            mask = plot_df["topic"] == topic
            ax.scatter(plot_df[mask]["x"], plot_df[mask]["y"], label=topic, alpha=0.6, s=50)
        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        ax.set_title("Baseline Semantic Space (PCA)", fontsize=14, pad=20)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        # Convert to image for Gradio
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)
        return metrics, img
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {e}", None
def analyze_embeddings_finetuned():
    count = finetuned_collection.count()
    if count < 5:
        return "Not enough data (need at least 5 chunks).", None
    try:
        limit = min(count, 2000)
        data = finetuned_collection.get(limit=limit, include=["embeddings", "metadatas"])
        X = torch.tensor(data['embeddings'])
        # Project to 2D with PCA for visualization
        X_mean = torch.mean(X, 0)
        X_centered = X - X_mean
        U, S, V = torch.pca_lowrank(X_centered, q=2)
        projected = torch.matmul(X_centered, V[:, :2]).numpy()
        # Diversity: 1 - average pairwise cosine similarity on a random sample
        indices = torch.randint(0, len(X), (min(100, len(X)),))
        sample = X[indices]
        sim_matrix = torch.mm(sample, sample.t())
        mask = ~torch.eye(len(sample), dtype=bool)
        avg_sim = sim_matrix[mask].mean().item()
        diversity_score = 1.0 - avg_sim
        metrics = (
            f"FINE-TUNED MODEL\n"
            f"Total Chunks: {count}\n"
            f"Analyzed: {len(X)}\n"
            f"Diversity Score: {diversity_score:.4f}\n"
            f"Avg Similarity: {avg_sim:.4f}"
        )
        plot_df = pd.DataFrame({
            "x": projected[:, 0],
            "y": projected[:, 1],
            "topic": [m.get("file_name", "unknown") for m in data['metadatas']]
        })
        import matplotlib.pyplot as plt
        import io
        from PIL import Image
        # Create matplotlib figure with proper spacing
        fig, ax = plt.subplots(figsize=(10, 8))
        fig.subplots_adjust(top=0.92)  # add space for the title
        # Plot each file with a different color
        unique_topics = plot_df["topic"].unique()
        for topic in unique_topics:
            mask = plot_df["topic"] == topic
            ax.scatter(plot_df[mask]["x"], plot_df[mask]["y"], label=topic, alpha=0.6, s=50)
        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        ax.set_title("Fine-tuned Semantic Space (PCA)", fontsize=14, pad=20)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        # Convert to image for Gradio
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)
        return metrics, img
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {e}", None
def evaluate_retrieval_baseline(sample_limit):
    count = baseline_collection.count()
    if count < 10:
        # This is a generator, so the message must be yielded, not returned
        yield "Not enough data for evaluation (need at least 10 chunks)."
        return
    try:
        fetch_limit = min(count, 2000)
        data = baseline_collection.get(limit=fetch_limit, include=["documents"])
        import random
        # int() guards against float values arriving from the gr.Slider input
        actual_sample_size = min(int(sample_limit), len(data['ids']))
        sample_indices = random.sample(range(len(data['ids'])), actual_sample_size)
        hits_at_1 = 0
        hits_at_5 = 0
        mrr_sum = 0
        yield f"BASELINE: Evaluating {actual_sample_size} chunks..."
        for i, idx in enumerate(sample_indices):
            target_id = data['ids'][idx]
            code = data['documents'][idx]
            # Synthetic query: the first 3 lines of the chunk
            query = "\n".join(code.split("\n")[:3])
            query_emb = compute_baseline_embeddings([query]).cpu().numpy().tolist()[0]
            results = baseline_collection.query(query_embeddings=[query_emb], n_results=10)
            found_ids = results['ids'][0]
            if target_id in found_ids:
                rank = found_ids.index(target_id) + 1
                mrr_sum += 1.0 / rank
                if rank == 1:
                    hits_at_1 += 1
                if rank <= 5:
                    hits_at_5 += 1
            if i % 10 == 0:
                yield f"Baseline: {i}/{actual_sample_size}..."
        recall_1 = hits_at_1 / actual_sample_size
        recall_5 = hits_at_5 / actual_sample_size
        mrr = mrr_sum / actual_sample_size
        report = (
            f"BASELINE EVALUATION ({actual_sample_size} chunks)\n"
            f"{'='*40}\n"
            f"Recall@1: {recall_1:.4f}\n"
            f"Recall@5: {recall_5:.4f}\n"
            f"MRR: {mrr:.4f}"
        )
        yield report
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {e}"
def evaluate_retrieval_finetuned(sample_limit):
    count = finetuned_collection.count()
    if count < 10:
        # This is a generator, so the message must be yielded, not returned
        yield "Not enough data for evaluation (need at least 10 chunks)."
        return
    try:
        fetch_limit = min(count, 2000)
        data = finetuned_collection.get(limit=fetch_limit, include=["documents"])
        import random
        # int() guards against float values arriving from the gr.Slider input
        actual_sample_size = min(int(sample_limit), len(data['ids']))
        sample_indices = random.sample(range(len(data['ids'])), actual_sample_size)
        hits_at_1 = 0
        hits_at_5 = 0
        mrr_sum = 0
        yield f"FINE-TUNED: Evaluating {actual_sample_size} chunks..."
        for i, idx in enumerate(sample_indices):
            target_id = data['ids'][idx]
            code = data['documents'][idx]
            # Synthetic query: the first 3 lines of the chunk
            query = "\n".join(code.split("\n")[:3])
            query_emb = compute_finetuned_embeddings([query]).cpu().numpy().tolist()[0]
            results = finetuned_collection.query(query_embeddings=[query_emb], n_results=10)
            found_ids = results['ids'][0]
            if target_id in found_ids:
                rank = found_ids.index(target_id) + 1
                mrr_sum += 1.0 / rank
                if rank == 1:
                    hits_at_1 += 1
                if rank <= 5:
                    hits_at_5 += 1
            if i % 10 == 0:
                yield f"Fine-tuned: {i}/{actual_sample_size}..."
        recall_1 = hits_at_1 / actual_sample_size
        recall_5 = hits_at_5 / actual_sample_size
        mrr = mrr_sum / actual_sample_size
        report = (
            f"FINE-TUNED EVALUATION ({actual_sample_size} chunks)\n"
            f"{'='*40}\n"
            f"Recall@1: {recall_1:.4f}\n"
            f"Recall@5: {recall_5:.4f}\n"
            f"MRR: {mrr:.4f}"
        )
        yield report
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {e}"
# --- UI ---
theme = gr.themes.Soft(primary_hue="slate", neutral_hue="slate", spacing_size="sm", radius_size="md").set(body_background_fill="*neutral_50", block_background_fill="white", block_border_width="1px", block_title_text_weight="600")
css = """
h1 { text-align: center; font-family: 'Inter', sans-serif; margin-bottom: 1rem; color: #1e293b; }
.gradio-container { max-width: 1400px !important; margin: auto; }
.comparison-header { font-size: 1.1rem; font-weight: 600; color: #334155; text-align: center; padding: 0.5rem; }
"""
with gr.Blocks(theme=theme, css=css, title="CodeMode - Baseline vs Fine-tuned") as demo:
    gr.Markdown("# CodeMode: Baseline vs Fine-tuned Model Comparison")
    gr.Markdown("Compare retrieval performance between **microsoft/codebert-base** (baseline) and the **MRL-enhanced fine-tuned** model.")
    with gr.Tabs():
        # TAB 1: INGEST
        with gr.Tab("1. Ingest Code"):
            with gr.Tabs():
                with gr.Tab("GitHub Repository"):
                    repo_input = gr.Textbox(label="GitHub URL", placeholder="https://github.com/pallets/flask")
                    ingest_url_btn = gr.Button("Ingest from URL", variant="primary")
                    url_status = gr.Textbox(label="Status")
                    ingest_url_btn.click(ingest_from_url, inputs=repo_input, outputs=url_status)
                with gr.Tab("Upload Python Files"):
                    file_upload = gr.File(label="Upload .py files", file_types=[".py"], file_count="multiple")
                    ingest_files_btn = gr.Button("Ingest Uploaded Files", variant="primary")
                    upload_status = gr.Textbox(label="Status")
                    ingest_files_btn.click(ingest_from_files, inputs=file_upload, outputs=upload_status)
            with gr.Row():
                reset_baseline_btn = gr.Button("Reset Baseline DB", variant="stop")
                reset_finetuned_btn = gr.Button("Reset Fine-tuned DB", variant="stop")
            reset_status = gr.Textbox(label="Reset Status")
            reset_baseline_btn.click(reset_baseline, inputs=[], outputs=reset_status)
            reset_finetuned_btn.click(reset_finetuned, inputs=[], outputs=reset_status)
gr.Markdown("---")
gr.Markdown("### Database Inspector")
gr.Markdown("View indexed files in each collection")
with gr.Row():
with gr.Column():
gr.Markdown("#### Baseline Collection")
inspect_baseline_btn = gr.Button("Inspect Baseline DB", variant="secondary")
baseline_files_df = gr.Dataframe(
headers=["File Name", "Chunks", "Source URL"],
datatype=["str", "number", "str"],
interactive=False,
value=[["No data yet", "-", "-"]]
)
inspect_baseline_btn.click(list_baseline_files, inputs=[], outputs=baseline_files_df)
with gr.Column():
gr.Markdown("#### Fine-tuned Collection")
inspect_finetuned_btn = gr.Button("Inspect Fine-tuned DB", variant="secondary")
finetuned_files_df = gr.Dataframe(
headers=["File Name", "Chunks", "Source URL"],
datatype=["str", "number", "str"],
interactive=False,
value=[["No data yet", "-", "-"]]
)
inspect_finetuned_btn.click(list_finetuned_files, inputs=[], outputs=finetuned_files_df)
gr.Markdown("---")
gr.Markdown("### Chunk Inspector")
gr.Markdown("View detailed chunk information for indexed files (content, metadata, schema)")
with gr.Row():
with gr.Column():
gr.Markdown("#### Baseline Collection")
baseline_file_dropdown = gr.Dropdown(
label="Select File to Inspect",
choices=[],
interactive=True
)
baseline_refresh_files = gr.Button("Refresh File List", variant="secondary")
baseline_chunks_display = gr.JSON(label="Chunk Details")
baseline_download_btn = gr.Button("Download Chunks as JSON", variant="primary")
baseline_download_output = gr.File(label="Download")
with gr.Column():
gr.Markdown("#### Fine-tuned Collection")
finetuned_file_dropdown = gr.Dropdown(
label="Select File to Inspect",
choices=[],
interactive=True
)
finetuned_refresh_files = gr.Button("Refresh File List", variant="secondary")
finetuned_chunks_display = gr.JSON(label="Chunk Details")
finetuned_download_btn = gr.Button("Download Chunks as JSON", variant="primary")
finetuned_download_output = gr.File(label="Download")
# Wire up Chunk Inspector events
baseline_refresh_files.click(
lambda: gr.Dropdown(choices=get_files_list_baseline()),
outputs=baseline_file_dropdown
)
baseline_file_dropdown.change(
get_chunks_for_file_baseline,
inputs=baseline_file_dropdown,
outputs=baseline_chunks_display
)
baseline_download_btn.click(
download_chunks_baseline,
inputs=baseline_file_dropdown,
outputs=baseline_download_output
)
finetuned_refresh_files.click(
lambda: gr.Dropdown(choices=get_files_list_finetuned()),
outputs=finetuned_file_dropdown
)
finetuned_file_dropdown.change(
get_chunks_for_file_finetuned,
inputs=finetuned_file_dropdown,
outputs=finetuned_chunks_display
)
finetuned_download_btn.click(
download_chunks_finetuned,
inputs=finetuned_file_dropdown,
outputs=finetuned_download_output
)
        # TAB 2: COMPARISON SEARCH
        with gr.Tab("2. Comparison Search"):
            gr.Markdown("### Side-by-Side Retrieval Comparison")
            gr.Markdown("Note: semantic search is sensitive to query phrasing.")
            search_query = gr.Textbox(label="Search Query", placeholder="e.g., 'Flask route decorator'")
            compare_btn = gr.Button("Compare Models", variant="primary")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("<div class='comparison-header'>BASELINE (CodeBERT)</div>", elem_classes="comparison-header")
                    baseline_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)
                with gr.Column():
                    gr.Markdown("<div class='comparison-header'>FINE-TUNED (MRL-Enhanced)</div>", elem_classes="comparison-header")
                    finetuned_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)
            compare_btn.click(search_comparison, inputs=search_query, outputs=[baseline_results, finetuned_results])
        # TAB 3: CODE SIMILARITY SEARCH
        with gr.Tab("3. Code Similarity Search"):
            gr.Markdown("### Find Similar Code Snippets")
            gr.Markdown("Paste a code snippet to find similar code in the database")
            with gr.Row():
                with gr.Column():
                    code_input = gr.Code(label="Paste Code Snippet", language="python", lines=10)
                    similarity_btn = gr.Button("Find Similar Code", variant="primary")
                with gr.Column():
                    gr.Markdown("#### Search Settings")
                    top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of Results")
                    model_choice = gr.Radio(["Baseline", "Fine-tuned", "Both"], value="Both", label="Model to Use")
            gr.Markdown("### Results")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Results")
                    baseline_code_results = gr.Dataframe(
                        headers=["File", "Similarity", "Code Snippet"],
                        datatype=["str", "str", "str"],
                        interactive=False,
                        wrap=True,
                        value=[["No search yet", "-", "-"]]
                    )
                with gr.Column():
                    gr.Markdown("#### Fine-tuned Results")
                    finetuned_code_results = gr.Dataframe(
                        headers=["File", "Similarity", "Code Snippet"],
                        datatype=["str", "str", "str"],
                        interactive=False,
                        wrap=True,
                        value=[["No search yet", "-", "-"]]
                    )

            def search_similar_code(code_snippet, top_k, model_choice):
                if not code_snippet or len(code_snippet.strip()) == 0:
                    empty = [["Enter code to search", "-", "-"]]
                    return empty, empty
                baseline_res = []
                finetuned_res = []
                if model_choice in ["Baseline", "Both"]:
                    baseline_res = search_baseline(code_snippet, top_k)
                    if not baseline_res:
                        baseline_res = [["No results found", "-", "-"]]
                if model_choice in ["Fine-tuned", "Both"]:
                    finetuned_res = search_finetuned(code_snippet, top_k)
                    if not finetuned_res:
                        finetuned_res = [["No results found", "-", "-"]]
                if model_choice == "Baseline":
                    finetuned_res = [["Not searched", "-", "-"]]
                elif model_choice == "Fine-tuned":
                    baseline_res = [["Not searched", "-", "-"]]
                return baseline_res, finetuned_res

            similarity_btn.click(
                search_similar_code,
                inputs=[code_input, top_k_slider, model_choice],
                outputs=[baseline_code_results, finetuned_code_results]
            )
        # TAB 4: DEPLOYMENT MONITORING
        with gr.Tab("4. Deployment Monitoring"):
            gr.Markdown("### Embedding Quality Analysis")
            gr.Markdown("Analyze the semantic space distribution and diversity of embeddings")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Model")
                    analyze_baseline_btn = gr.Button("Analyze Baseline Embeddings", variant="secondary")
                    baseline_metrics = gr.Textbox(label="Baseline Metrics")
                    baseline_plot = gr.Image()
                    analyze_baseline_btn.click(analyze_embeddings_baseline, inputs=[], outputs=[baseline_metrics, baseline_plot])
                with gr.Column():
                    gr.Markdown("#### Fine-tuned Model")
                    analyze_finetuned_btn = gr.Button("Analyze Fine-tuned Embeddings", variant="secondary")
                    finetuned_metrics = gr.Textbox(label="Fine-tuned Metrics")
                    finetuned_plot = gr.Image()
                    analyze_finetuned_btn.click(analyze_embeddings_finetuned, inputs=[], outputs=[finetuned_metrics, finetuned_plot])
            gr.Markdown("---")
            gr.Markdown("### Retrieval Performance Evaluation")
            gr.Markdown("Evaluate retrieval accuracy using synthetic queries (query = first 3 lines of code)")
            eval_size = gr.Slider(minimum=10, maximum=500, value=50, step=10, label="Sample Size (Chunks to Evaluate)")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Evaluation")
                    eval_baseline_btn = gr.Button("Run Baseline Evaluation", variant="primary")
                    baseline_eval_output = gr.Textbox(label="Baseline Results")
                    eval_baseline_btn.click(evaluate_retrieval_baseline, inputs=[eval_size], outputs=baseline_eval_output)
                with gr.Column():
                    gr.Markdown("#### Fine-tuned Evaluation")
                    eval_finetuned_btn = gr.Button("Run Fine-tuned Evaluation", variant="primary")
                    finetuned_eval_output = gr.Textbox(label="Fine-tuned Results")
                    eval_finetuned_btn.click(evaluate_retrieval_finetuned, inputs=[eval_size], outputs=finetuned_eval_output)
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)