import os
import time
from threading import Thread
from typing import Iterator, Optional

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)
# Selectable OCR models, keyed by the label shown in the UI model selector.
# Each value is a (Hugging Face repo id, model class used to load it) pair.
# NOTE: the label is also the value the selector sends back to the handler,
# so renaming it (including the "handwritting" spelling) changes the input
# that UI/API callers must pass.
MODEL_PATHS = {}
MODEL_PATHS["Model 3 (structured handwritting)"] = (
    "Emeritus-21/Finetuned-full-HTR-model",
    AutoModelForImageTextToText,
)

# Default upper bound on generated tokens for one transcription request.
MAX_NEW_TOKENS_DEFAULT = 512

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Caches of loaded artefacts, keyed by MODEL_PATHS label. A label appears
# here only if both its processor and its model loaded successfully.
_loaded_processors = {}
_loaded_models = {}

# Half precision is used only on CUDA: many CPU kernels lack float16
# implementations (or are very slow in it), so the CPU fallback loads float32.
_model_dtype = torch.float16 if device == "cuda" else torch.float32

print("Preloading models into GPU/CPU memory...")

# Eagerly load every registered model at import time. Failures are reported
# and skipped (best effort) so one broken checkpoint does not stop the app.
for name, (repo_id, cls) in MODEL_PATHS.items():
    try:
        print(f"Loading {name} ...")
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        model = cls.from_pretrained(
            repo_id,
            trust_remote_code=True,
            torch_dtype=_model_dtype,
        ).to(device).eval()
    except Exception as e:
        print(f"⚠️ Failed to load {name}: {e}")
    else:
        # Register only after both loads succeed, keeping the caches in sync.
        _loaded_processors[name] = processor
        _loaded_models[name] = model
        print(f"✅ {name} ready.")

def warmup():
    """Prime the default OCR model with a one-token, text-only generation.

    The default model is the first key of MODEL_PATHS. This is best effort:
    any exception (an empty registry, a model that failed to load, a
    generation error, ...) is turned into a status string instead of being
    raised.

    Returns:
        ``"GPU warm and <label> ready."`` on success (the wording says "GPU"
        even when ``device`` is the CPU), otherwise ``"Warmup skipped: <error>"``.
    """
    try:
        label = list(MODEL_PATHS)[0]
        proc = _loaded_processors[label]
        mdl = _loaded_models[label]

        # A plain text turn (no image) is enough to run generate() once.
        warm_turn = {
            "role": "user",
            "content": [{"type": "text", "text": "Warmup."}],
        }
        rendered = proc.apply_chat_template(
            [warm_turn], tokenize=False, add_generation_prompt=True
        )
        batch = proc(text=[rendered], images=None, return_tensors="pt")
        batch = batch.to(device)

        # Inference only: a single new token, with autograd tracking disabled.
        with torch.inference_mode():
            mdl.generate(**batch, max_new_tokens=1)
    except Exception as exc:
        return f"Warmup skipped: {exc}"
    return f"GPU warm and {label} ready."

# Instruction used when the caller does not supply a custom prompt.
_DEFAULT_OCR_PROMPT = (
    "You are a professional Handwritten OCR system.\n"
    "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
    "- Preserve original structure and line breaks.\n"
    "- Keep spacing, bullet points, numbering, and indentation.\n"
    "- Render tables as Markdown tables if present.\n"
    "- Do NOT autocorrect spelling or grammar.\n"
    "- Do NOT merge lines.\n"
    "Return RAW transcription only."
)


# NOTE(review): `spaces` is imported at file level but this handler is not
# decorated with `@spaces.GPU`; if the app is deployed on ZeroGPU hardware
# that decorator is required — confirm the deployment target.
def ocr_image(image: Image.Image, model_choice: str, query: Optional[str] = None,
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0,
              repetition_penalty: float = 1.0,
              do_sample: bool = False) -> Iterator[str]:
    """Stream a transcription of ``image`` produced by the selected model.

    This is a generator. Each yielded value is the full text accumulated so
    far, not only the newest chunk, so a UI can simply replace its output.

    Args:
        image: Image to transcribe; ``None`` yields a request to upload one.
        model_choice: A MODEL_PATHS label whose model loaded successfully.
        query: Optional custom instruction. ``None`` or a blank string falls
            back to the built-in raw-transcription prompt.
        max_new_tokens: Upper bound on generated tokens.
        temperature, top_p, top_k: Sampling controls. They are forwarded to
            ``generate`` only when ``do_sample`` is true; under greedy
            decoding they have no effect and transformers warns about them.
        repetition_penalty: Forwarded to ``generate`` in both modes.
        do_sample: Sample instead of decoding greedily. Defaults to ``False``,
            matching the previous always-greedy behaviour.

    Yields:
        Progressive transcription text, or a single status message. If
        generation fails on the worker thread, a final message carrying the
        error is yielded instead of hanging.
    """
    if image is None:
        yield "Please upload an image."
        return

    if model_choice not in _loaded_models:
        yield f"Invalid model: {model_choice}"
        return

    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]

    prompt = query.strip() if query and query.strip() else _DEFAULT_OCR_PROMPT

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }]
    chat_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[chat_prompt], images=[image], return_tensors="pt").to(device)

    # skip_prompt drops the echoed input; decoded text arrives chunk by chunk.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)

    # int() guards against UI widgets delivering whole numbers as floats.
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=bool(do_sample),
        repetition_penalty=repetition_penalty,
    )
    if do_sample:
        generation_kwargs.update(temperature=temperature, top_p=top_p, top_k=int(top_k))

    failure = []

    def _generate():
        # Runs on the worker thread. If generate() raises, the streamer never
        # receives its end-of-stream signal and the loop below would block
        # forever, so record the error and close the stream explicitly.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            failure.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        # Defensive strip of a literal end-of-turn marker; presumably already
        # removed by skip_special_tokens=True.
        buffer += new_text.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer

    thread.join()
    if failure:
        message = f"Generation failed: {failure[0]}"
        yield f"{buffer}\n\n{message}" if buffer else message

with gr.Blocks() as demo:
    gr.Markdown("## wilson Handwritten OCR ")

    # The selected label is forwarded to ocr_image() as ``model_choice``;
    # the first registered model is preselected.
    model_choice = gr.Radio(
        choices=list(MODEL_PATHS.keys()),
        value=list(MODEL_PATHS.keys())[0],
        label="Select OCR Model",
    )

    with gr.Tab("🖼 Image Inference"):
        query_input = gr.Textbox(
            label="Custom Prompt (optional)",
            placeholder="Leave empty for RAW structured output",
        )
        image_input = gr.Image(type="pil", label="Upload Handwritten Image")

        # Generation controls, passed positionally to ocr_image() below.
        # NOTE: ocr_image() decodes greedily (do_sample=False) by default, in
        # which case the temperature / top-p / top-k sliders have no effect.
        with gr.Accordion("⚙️ Advanced Options", open=False):
            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")

        with gr.Row():
            extract_btn = gr.Button("Extract RAW Text", variant="primary")
            clear_btn = gr.Button("🧹 Clear")

        raw_output = gr.Textbox(
            label="RAW Structured Output (exact as written)",
            lines=18,
            show_copy_button=True,
        )

        # ocr_image is a generator, so the output box updates as text streams in.
        extract_btn.click(
            fn=ocr_image,
            inputs=[image_input, model_choice, query_input, max_new_tokens,
                    temperature, top_p, top_k, repetition_penalty],
            outputs=[raw_output],
            api_name="ocr_image",
        )

        # Reset the output text, the uploaded image and the custom prompt.
        clear_btn.click(
            fn=lambda: ("", None, ""),
            outputs=[raw_output, image_input, query_input],
        )

if __name__ == "__main__":
    # Cap pending requests at 50, then start the server with:
    #   share=True      -> request a public Gradio share link,
    #   ssr_mode=False  -> disable server-side rendering,
    #   show_error=True -> display handler exceptions to the user.
    app = demo.queue(max_size=50)
    app.launch(share=True, ssr_mode=False, show_error=True)