"""Gradio front end for the worksheet solver.

On import the module makes sure a gap-detection model is available locally
(downloading it from the project's GitHub releases if needed) and builds the
Gradio UI. Running the file as a script launches the web server.
"""

import os
import re
import tempfile
import uuid
import warnings
from pathlib import Path

import gradio as gr
import requests
from PIL import Image

from main import WorksheetSolver

warnings.filterwarnings("ignore")

ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "webp", "bmp"}
RELEASES_URL = "https://github.com/Hawk3388/solver/releases"

_MODEL_DIR = Path("./model")
_MODEL_FILENAME = "gap_detection_model.pt"

# Folder names under ./model that can be parsed as a dotted version number.
_VERSION_DIR_RE = re.compile(r"v*\d+(?:\.\d+)*")

# Extracts "vX.Y.Z" tags from the releases page HTML.
# NOTE(review): the leading "]*" looks like the remnant of a stripped
# "<a[^>]*>" prefix; as written the pattern matches any ">vX.Y.Z" -- confirm
# which form is intended before tightening it.
_RELEASE_TAG_RE = re.compile(r"]*>(v\d+\.\d+\.\d+)")

_APP_CSS = """
.app-shell {max-width: 1200px; margin: 0 auto;}
.hero {text-align: center; margin: 14px 0 8px;}
.hero h1 {font-size: 2rem; margin-bottom: 6px;}
.hero p {opacity: 0.85;}
"""


def _version_key(version: str) -> list:
    """Return the numeric components of a ``vX.Y.Z`` string for comparison."""
    return [int(part) for part in version.lstrip("v").split(".")]


def _download_file(url: str, destination: Path, timeout: float = 60) -> None:
    """Stream ``url`` into ``destination``, creating parent folders as needed.

    The payload is written to a ``.part`` file next to the destination and
    renamed only after the transfer completed, so an interrupted download
    never leaves a truncated file at the final path.

    Raises:
        requests.RequestException: on connection problems or a non-2xx status.
    """
    destination.parent.mkdir(parents=True, exist_ok=True)
    partial_path = destination.with_name(destination.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an HTML error page would be saved as the model.
            response.raise_for_status()
            with open(partial_path, "wb") as model_file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        model_file.write(chunk)
        os.replace(partial_path, destination)
    finally:
        # Only still present when the transfer or the rename failed.
        if partial_path.exists():
            partial_path.unlink()


def get_gap_model() -> str:
    """Return the path of the gap-detection model, downloading it if needed.

    The highest version folder under ``./model`` is treated as the local
    model. The GitHub releases page is then scanned for version tags:

    * without a usable local model, the first listed version whose model
      asset is reachable is downloaded;
    * with a local model, the first listed version that is newer than the
      local one and whose asset is reachable is downloaded;
    * otherwise the local model is kept.

    Returns:
        Path to ``gap_detection_model.pt`` as a string.

    Raises:
        RuntimeError: if the releases page cannot be fetched or parsed, or if
            no model is available locally or remotely.
        requests.RequestException: if a network request fails.
    """
    _MODEL_DIR.mkdir(parents=True, exist_ok=True)

    # Ignore unrelated folders so a stray directory cannot break int parsing.
    local_versions = [
        entry.name
        for entry in _MODEL_DIR.iterdir()
        if entry.is_dir() and _VERSION_DIR_RE.fullmatch(entry.name)
    ]
    latest_version = max(local_versions, key=_version_key) if local_versions else None

    local_model = None
    if latest_version is not None:
        candidate = _MODEL_DIR / latest_version / _MODEL_FILENAME
        if candidate.exists():
            local_model = candidate

    release_response = requests.get(RELEASES_URL, timeout=30)
    if release_response.status_code != 200:
        raise RuntimeError(
            f"Failed to fetch releases from GitHub: {release_response.status_code}"
        )

    # A tag can be matched several times in the page; keep the first
    # occurrence of each so every asset URL is probed at most once.
    versions = list(dict.fromkeys(_RELEASE_TAG_RE.findall(release_response.text)))
    if not versions:
        raise RuntimeError(
            "Could not determine the latest model version from GitHub releases."
        )

    for version in versions:
        # With a working local model only strictly newer releases matter.
        if local_model is not None and _version_key(version) <= _version_key(latest_version):
            continue

        model_url = (
            f"https://github.com/Hawk3388/solver/releases/download/{version}/gap_detection_model.pt"
        )
        if not url_exists(model_url):
            continue

        target = _MODEL_DIR / version / _MODEL_FILENAME
        _download_file(model_url, target)
        return str(target)

    if local_model is not None:
        return str(local_model)
    raise RuntimeError(
        "No downloadable gap detection model was found in the GitHub releases."
    )


def url_exists(url: str, timeout: float = 5.0) -> bool:
    """Return True if a HEAD request to ``url`` answers with a 2xx/3xx status.

    Redirects are followed. Any ``requests`` error (timeout, DNS failure,
    connection error, ...) is reported as ``False``.
    """
    try:
        response = requests.head(url, allow_redirects=True, timeout=timeout)
    except requests.RequestException:
        return False
    return 200 <= response.status_code < 400


def _is_allowed_image(filename: str) -> bool:
    """Return True if the text after the last dot is an allowed extension."""
    _, dot, extension = filename.rpartition(".")
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS


def solve_worksheet(image_path: str):
    """Solve the worksheet stored at ``image_path`` and return the result.

    The upload is converted to an RGB PNG inside a temporary directory, gaps
    are detected and marked, the solver produces the solutions, and they are
    rendered into a new image.

    Args:
        image_path: Path of the uploaded image file.

    Returns:
        An in-memory copy of the solved worksheet image.

    Raises:
        gr.Error: if no or an unsupported file was supplied, no gaps or
            solutions were found, or any processing step failed.
    """
    if not image_path:
        raise gr.Error("Please upload an image first.")
    if not _is_allowed_image(image_path):
        raise gr.Error("Please upload a valid image file (PNG, JPG, JPEG, WEBP, BMP).")

    with tempfile.TemporaryDirectory() as tmp_dir:
        unique_id = uuid.uuid4().hex
        input_path = os.path.join(tmp_dir, f"{unique_id}.png")
        output_path = os.path.join(tmp_dir, f"{unique_id}_solved.png")

        try:
            # Close the upload as soon as the normalized copy is written.
            with Image.open(image_path) as uploaded:
                uploaded.convert("RGB").save(input_path)

            solver = WorksheetSolver(
                input_path,
                gap_detection_model_path=MODEL_PATH,
                llm_model_name="gemini-3-flash-preview",
                think=True,
                local=False,
                thinking_budget=2048,
                debug=False,
                experimental=False,
            )

            gaps, detected_image = solver.detect_gaps()
            if not gaps:
                raise gr.Error("No gaps were detected. Please try a clearer worksheet image.")

            marked_image = solver.mark_gaps(detected_image, gaps)
            solutions = solver.solve_all_gaps(marked_image)
            if not solutions:
                raise gr.Error("The AI could not find any solutions.")

            solver.fill_gaps_in_image(input_path, solutions, output_path=output_path)

            # Copy into memory and release the file handle before the
            # temporary directory is removed.
            with Image.open(output_path) as solved_file:
                return solved_file.copy()
        except gr.Error:
            # Already a user-facing message; do not wrap it a second time.
            raise
        except Exception as error:
            raise gr.Error(f"Processing error: {error}") from error


def build_app() -> gr.Blocks:
    """Build the Gradio UI: upload on the left, solved worksheet on the right."""
    with gr.Blocks(title="Worksheet Solver", css=_APP_CSS) as demo:
        # The "hero" wrapper and the h1/p tags are what the CSS rules target.
        gr.HTML(
            """
            <div class="hero">
                <h1>Worksheet Solver</h1>
                <p>Upload a worksheet image and generate the solved version.</p>
            </div>
            """
        )

        with gr.Row(elem_classes=["app-shell"]):
            with gr.Column(scale=1):
                image_input = gr.Image(
                    type="filepath",
                    label="Worksheet Image",
                    sources=["upload"],
                )
                solve_button = gr.Button("Solve", variant="primary")
            with gr.Column(scale=1):
                image_output = gr.Image(type="pil", label="Solved Worksheet")

        solve_button.click(
            fn=solve_worksheet,
            inputs=image_input,
            outputs=image_output,
        )

    return demo


# Resolved once at import time; solve_worksheet reads MODEL_PATH on each call.
MODEL_PATH = get_gap_model()
demo = build_app()

if __name__ == "__main__":
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        share=True,
    )