"""Gradio app that scores text with the desklib AI-text detector.

The text is split into sentences, each sentence is scored together with its
immediate neighbours, and the result is shown as a word-weighted global score,
an HTML heatmap and a per-sentence table.
"""

import html
import os
import re
import shutil

# ============================================================
# ENV (set BEFORE transformers/hub usage)
# ============================================================
# setdefault keeps any value already present in the environment, so the real
# cache location may differ from the /tmp defaults below.
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/hf/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")  # disable hf-xet if present
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import gradio as gr

from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoTokenizer, AutoModel
from safetensors.torch import load_file

# -----------------------------
# MODEL INITIALIZATION
# -----------------------------
MODEL_NAME = "desklib/ai-text-detector-v1.01"
tokenizer = None  # populated lazily by load_desklib_model()
model = None  # populated lazily by load_desklib_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
THRESHOLD = 0.59  # sentences scoring at or above this are highlighted as AI


def _build_error_card(msg: str) -> str:
    """Wrap *msg* in a red HTML card for the heatmap panel.

    *msg* is inserted as an HTML fragment (callers may include ``<br>``), so
    any untrusted text must be escaped by the caller before it is passed in.
    """
    # NOTE(review): the original inline styling was not recoverable from the
    # source; this is a minimal replacement card.
    return (
        "<div style='padding: 12px; border-radius: 8px; "
        "border: 1px solid #d32f2f; background: rgba(211, 47, 47, 0.08); "
        "color: #d32f2f;'>"
        f"{msg}</div>"
    )


def wipe_model_cache(model_id: str) -> int:
    """
    Delete cached files for this model from common HF cache locations.

    Besides the fixed /tmp and home-directory locations, the directories named
    by the HF cache environment variables are checked too, because the
    ``setdefault`` calls at import time do not override pre-existing values.

    Returns the number of cache directories that were actually removed.
    """
    folder = "models--" + model_id.replace("/", "--")

    roots = [
        # our /tmp cache (recommended)
        "/tmp/hf/hub",
        "/tmp/hf/transformers",
        # default home cache (in case something wrote there)
        os.path.expanduser("~/.cache/huggingface/hub"),
        os.path.expanduser("~/.cache/huggingface/transformers"),
        os.path.expanduser("~/.cache/huggingface/modules"),
    ]
    # Caches relocated through the environment.
    for var in ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE", "TRANSFORMERS_CACHE"):
        value = os.environ.get(var)
        if value:
            roots.append(value)
    hf_home = os.environ.get("HF_HOME")
    if hf_home:
        roots.append(os.path.join(hf_home, "hub"))

    removed = 0
    seen = set()
    for root in roots:
        path = os.path.normpath(os.path.join(root, folder))
        if path in seen:
            continue
        seen.add(path)
        if not os.path.exists(path):
            continue
        shutil.rmtree(path, ignore_errors=True)
        # ignore_errors hides failures, so only count what is really gone.
        if not os.path.exists(path):
            removed += 1
    return removed


class DesklibAIDetectionModel(nn.Module):
    """
    Matches the architecture described by desklib:
    base transformer + mean pooling + linear classifier to 1 logit.
    The repo config lists "architectures": ["DesklibAIDetectionModel"].
    """

    def __init__(self, config):
        super().__init__()
        self.backbone = AutoModel.from_config(config)
        self.classifier = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        """Return one logit per input row, shape (B, 1)."""
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state  # (B, T, H)

        if attention_mask is None:
            pooled = last_hidden.mean(dim=1)
        else:
            # Masked mean: padding positions contribute nothing to the sum
            # and are excluded from the divisor.
            mask = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
            summed = torch.sum(last_hidden * mask, dim=1)
            denom = torch.clamp(mask.sum(dim=1), min=1e-9)
            pooled = summed / denom

        return self.classifier(pooled)  # (B, 1)


def _align_backbone_keys(state: dict, expected_keys: set) -> dict:
    """Rename checkpoint keys whose top-level prefix differs from ``backbone.``.

    A key is renamed only when swapping its first path component for
    ``backbone`` yields a key the module expects and that the checkpoint does
    not already provide; everything else is passed through untouched.
    """
    aligned = {}
    for key, tensor in state.items():
        if key not in expected_keys and "." in key:
            candidate = "backbone." + key.split(".", 1)[1]
            if (
                candidate in expected_keys
                and candidate not in state
                and candidate not in aligned
            ):
                key = candidate
        aligned[key] = tensor
    return aligned


def load_desklib_model(force_redownload: bool = False):
    """
    Robust loader:
    - downloads config/tokenizer normally
    - downloads model.safetensors explicitly
    - loads safetensors via safetensors.torch.load_file
    - loads into our matching PyTorch module with strict=False

    The result is cached in the module-level ``tokenizer`` / ``model`` globals.
    With ``force_redownload=True`` the on-disk cache is wiped and everything is
    fetched again. Returns ``(tokenizer, model)``; download or parsing errors
    propagate to the caller.
    """
    global tokenizer, model

    if not force_redownload and tokenizer is not None and model is not None:
        return tokenizer, model

    if force_redownload:
        print("💣 NUKE requested: wiping cache + forcing fresh downloads...")
        removed = wipe_model_cache(MODEL_NAME)
        print(f"🧹 Cache dirs removed: {removed}")
        # Drop the old objects first so their memory can be reclaimed.
        tokenizer = None
        model = None

    print(f"🚀 Loading tokenizer/config: {MODEL_NAME}")
    config = AutoConfig.from_pretrained(MODEL_NAME, force_download=force_redownload)
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME, force_download=force_redownload
    )

    print("⬇️ Downloading model.safetensors explicitly...")
    weights_path = hf_hub_download(
        repo_id=MODEL_NAME,
        filename="model.safetensors",
        force_download=force_redownload,
    )

    size_gb = os.path.getsize(weights_path) / (1024**3)
    print(f"✅ model.safetensors path: {weights_path}")
    print(f"✅ model.safetensors size: {size_gb:.2f} GB")

    # Build model + load weights
    print("🧠 Building DesklibAIDetectionModel + loading weights...")
    detector = DesklibAIDetectionModel(config)
    state = load_file(weights_path)  # this will throw if file is truly corrupt
    # The checkpoint may store the backbone under another attribute name;
    # without this, strict=False would silently leave it randomly initialised.
    state = _align_backbone_keys(state, set(detector.state_dict()))
    missing, unexpected = detector.load_state_dict(state, strict=False)

    # Helpful debug (won't crash)
    if missing:
        print(f"⚠️ Missing keys (first 20): {missing[:20]}")
    if unexpected:
        print(f"⚠️ Unexpected keys (first 20): {unexpected[:20]}")

    loaded_model = detector.to(device).eval()
    # Publish both globals together so a failed load never leaves a
    # tokenizer cached without its model.
    tokenizer, model = loaded_tokenizer, loaded_model
    return tokenizer, model


# -----------------------------
# UTILITIES
# -----------------------------
ABBR = [
    "e.g", "i.e", "mr", "mrs", "ms", "dr", "prof", "vs", "etc",
    "fig", "al", "jr", "sr", "st", "inc", "ltd", "u.s", "u.k",
]
ABBR_REGEX = re.compile(
    r"\b(" + "|".join(map(re.escape, ABBR)) + r")\.", re.IGNORECASE
)


def _protect(text):
    """Replace dots that do not end a sentence with placeholder tokens."""
    text = text.replace("...", "⟨ELLIPSIS⟩")
    text = re.sub(r"(?<=\d)\.(?=\d)", "⟨DECIMAL⟩", text)
    text = ABBR_REGEX.sub(r"\1⟨ABBRDOT⟩", text)
    return text


def _restore(text):
    """Undo the placeholder substitutions made by ``_protect``."""
    return (
        text.replace("⟨ABBRDOT⟩", ".")
        .replace("⟨DECIMAL⟩", ".")
        .replace("⟨ELLIPSIS⟩", "...")
    )


def split_preserving_structure(text):
    """Split *text* into sentences, whitespace runs and newline runs.

    Returns the pieces in their original order: sentences (with their
    terminator), the whitespace that followed them, and runs of newlines.
    """
    final_blocks = []
    for block in re.split(r"(\n+)", text):
        if not block:
            continue
        if block.startswith("\n"):
            final_blocks.append(block)
            continue

        protected = _protect(block)
        # The two capture groups make re.split yield triples of
        # (sentence, terminator, trailing whitespace).
        parts = re.split(r"([.?!])(\s+)", protected)
        for i in range(0, len(parts), 3):
            sentence = parts[i]
            punct = parts[i + 1] if i + 1 < len(parts) else ""
            space = parts[i + 2] if i + 2 < len(parts) else ""
            if sentence.strip():
                final_blocks.append(_restore(sentence + punct))
            if space:
                final_blocks.append(space)
    return final_blocks


# -----------------------------
# ANALYSIS
# -----------------------------
@torch.inference_mode()
def analyze(text):
    """Score *text* and build the outputs for the UI.

    Returns a 5-tuple ``(verdict, score, heatmap_html, dataframe, status)``
    matching the Gradio outputs wired to the "Analyze Text" button. Inputs
    shorter than 250 words are rejected with a warning card.
    """
    text = (text or "").strip()
    if not text:
        return "—", "—", "Please enter text...", None, ""

    word_count = len(text.split())
    if word_count < 250:
        warning_msg = (
            f"⚠️ Insufficient Text: Your input has {word_count} words. "
            "Please enter at least 250 words for accurate results."
        )
        return "Too Short", "N/A", _build_error_card(warning_msg), None, ""

    try:
        tok, mod = load_desklib_model(force_redownload=False)
    except Exception as e:
        # Exception text is untrusted as far as HTML is concerned.
        detail = html.escape(str(e))
        return (
            "ERROR",
            "0%",
            _build_error_card(f"Failed to load model:<br>{detail}"),
            None,
            "",
        )

    blocks = split_preserving_structure(text)
    pure_sents_indices = [
        i for i, b in enumerate(blocks) if b.strip() and not b.startswith("\n")
    ]
    pure_sents = [blocks[i] for i in pure_sents_indices]
    if not pure_sents:
        return "—", "—", "No sentences detected.", None, ""

    # Each sentence is scored together with its previous and next sentence.
    windows = [
        " ".join(pure_sents[max(0, i - 1): i + 2]) for i in range(len(pure_sents))
    ]

    batch_size = 8
    probs = []
    for start in range(0, len(windows), batch_size):
        batch = windows[start: start + batch_size]
        inputs = tok(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)
        logits = mod(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
        )
        probs.extend(torch.sigmoid(logits).flatten().cpu().tolist())

    lengths = [len(s.split()) for s in pure_sents]
    total_words = sum(lengths)
    weighted_avg = (
        sum(p * n for p, n in zip(probs, lengths)) / total_words
        if total_words > 0
        else 0
    )

    # HTML Heatmap
    # NOTE(review): the original markup/styling was not recoverable from the
    # source; the tags below are a minimal reconstruction that uses the
    # colour values computed per sentence.
    prob_map = dict(zip(pure_sents_indices, probs))
    pieces = ["<div style='line-height: 1.9; font-size: 1rem;'>"]
    for i, block in enumerate(blocks):
        if block.startswith("\n") or block.isspace():
            pieces.append(block.replace("\n", "<br>"))
            continue

        safe_block = html.escape(block)  # user text must not inject markup
        if i not in prob_map:
            pieces.append(safe_block)
            continue

        score = prob_map[i]
        if score >= THRESHOLD:
            color, bg = "#d32f2f", "rgba(211, 47, 47, 0.12)"
            border = "2px solid #d32f2f"
        else:
            color, bg = "#2e7d32", "rgba(46, 125, 50, 0.08)"
            border = "1px solid transparent"
        pieces.append(
            f"<span style='background: {bg}; border-bottom: {border}; "
            "padding: 1px 2px; border-radius: 3px;'>"
            f"<sup style='color: {color}; font-size: 0.7em; "
            f"margin-right: 2px;'>{score:.0%}</sup>"
            f"{safe_block}</span>"
        )
    pieces.append("</div>")
    highlighted_html = "".join(pieces)

    label = f"{weighted_avg:.1%} AI Written"
    display_score = f"{weighted_avg:.2%}"
    df = pd.DataFrame(
        {"Sentence": pure_sents, "AI Confidence": [f"{p:.2%}" for p in probs]}
    )
    return label, display_score, highlighted_html, df, ""


def nuke_and_reload():
    """Wipe the model cache, reload the model and return a Markdown status."""
    try:
        load_desklib_model(force_redownload=True)
        return (
            "✅ **Nuked cache and reloaded model successfully.**\n\n"
            "- Cache wiped\n"
            "- Fresh download forced\n"
            "- Custom loader used (DesklibAIDetectionModel)\n"
            "- Model ready ✅"
        )
    except Exception as e:
        return (
            "❌ **Nuke attempted but model still failed to load.**\n\n"
            f"**Error:** `{str(e)}`\n\n"
            "If this error happens inside `load_file(model.safetensors)`, the file is truly corrupted/truncated.\n"
            "If it happens after that, it’s likely key mismatches (shown in logs as missing/unexpected keys)."
        )


# -----------------------------
# INTERFACE
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="AI Detector Pro") as demo:
    gr.Markdown("# 🕵️ AI Detector Pro")
    gr.Markdown(
        f"Model: **{MODEL_NAME}** | Highlight Threshold: **{THRESHOLD*100:.0f}%**"
    )

    with gr.Row():
        with gr.Column(scale=3):
            text_input = gr.Textbox(
                label="Input Text",
                lines=15,
                placeholder="Enter at least 250 words...",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                run_btn = gr.Button("Analyze Text", variant="primary")
                nuke_btn = gr.Button("💣 Nuke Model Cache", variant="stop")

        with gr.Column(scale=1):
            verdict_out = gr.Label(label="Global Verdict")
            score_out = gr.Label(label="Weighted Probability")
            status_out = gr.Markdown()

    with gr.Tabs():
        with gr.TabItem("Visual Heatmap"):
            html_out = gr.HTML()
        with gr.TabItem("Data Breakdown"):
            table_out = gr.Dataframe(headers=["Sentence", "AI Confidence"], wrap=True)

    run_btn.click(
        analyze,
        inputs=text_input,
        outputs=[verdict_out, score_out, html_out, table_out, status_out],
    )

    def _clear():
        """Reset the input box and every output to its initial state."""
        return "", "—", "—", "Please enter text...", None, ""

    clear_btn.click(
        _clear,
        outputs=[text_input, verdict_out, score_out, html_out, table_out, status_out],
    )
    nuke_btn.click(nuke_and_reload, outputs=status_out)

if __name__ == "__main__":
    demo.launch()