# GLM-OCR / handler.py
# defford's picture
# Update handler.py
# f06d29f verified
import asyncio
import base64
import io
import threading

import torch
from fastapi import FastAPI, Request
from PIL import Image
from transformers import GlmOcrProcessor, GlmOcrForConditionalGeneration
app = FastAPI()


class EndpointHandler():
    """Callable that runs GLM-OCR over a base64-encoded image.

    The processor and model are loaded once in ``__init__``; each call
    decodes one image, prompts the model with ``PROMPT`` and returns the
    generated text.
    """

    # Bookkeeping prompt - native formatting.
    PROMPT = "Extract receipt items into JSON: [{date, vendor, description, qty, price, total}]"
    # Upper bound on the number of tokens generated per request.
    MAX_NEW_TOKENS = 1024

    def __init__(self, path=""):
        """Load the GLM-OCR processor and model from *path*.

        Args:
            path: Directory (or model id) passed to ``from_pretrained``.
        """
        # Native 5.1.0 classes specifically for GLM-OCR
        self.processor = GlmOcrProcessor.from_pretrained(path)
        self.model = GlmOcrForConditionalGeneration.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        self.model.eval()

    @staticmethod
    def _decode_image(payload):
        """Turn a base64 payload (optionally a ``data:`` URI) into an RGB image.

        Raises:
            TypeError: If *payload* is neither ``str`` nor bytes-like.
            binascii.Error: If the base64 text is malformed.
            PIL.UnidentifiedImageError: If the bytes are not a readable image.
        """
        if not isinstance(payload, (str, bytes, bytearray)):
            raise TypeError(
                "expected a base64-encoded image string in 'inputs', "
                f"got {type(payload).__name__}"
            )
        if isinstance(payload, str) and payload.startswith("data:"):
            # b64decode silently drops non-alphabet characters, so a data-URI
            # header would otherwise be decoded as garbage image bytes.
            _, _, payload = payload.partition(",")
        image_bytes = base64.b64decode(payload)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data):
        """Run OCR extraction on one request payload.

        Args:
            data: Either a dict whose ``"inputs"`` entry holds the base64
                image (the original comment says this is sent by Google
                Sheets), or the base64 payload itself.

        Returns:
            A one-element list ``[{"generated_text": <model output>}]``.

        Raises:
            TypeError: If no base64 string/bytes payload can be found.
        """
        # Read without popping so the caller's dict is left untouched; a
        # non-dict payload is treated as the image data itself.
        inputs_data = data.get("inputs", data) if isinstance(data, dict) else data
        image = self._decode_image(inputs_data)

        # New 5.1.0 process workflow
        # NOTE(review): the GLM-OCR model card builds inputs through
        # processor.apply_chat_template(...); confirm that a plain-text prompt
        # produces the image placeholder tokens the model expects.
        inputs = self.processor(
            images=image, text=self.PROMPT, return_tensors="pt"
        ).to(self.model.device)

        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs, max_new_tokens=self.MAX_NEW_TOKENS
            )

        # generate() returns prompt + continuation; keep only the new tokens
        # so the instruction text is not echoed back in the result.
        prompt_length = inputs["input_ids"].shape[1]
        new_token_ids = generated_ids[:, prompt_length:]

        # Decode results
        result = self.processor.batch_decode(
            new_token_ids, skip_special_tokens=True
        )[0]
        return [{"generated_text": result}]


# Created only after the class definition above; instantiating it earlier
# raised NameError at import time.
handler = EndpointHandler("/repository")
# Keeps inference strictly one-at-a-time, as it was when the handler ran
# directly on the event loop.
_inference_lock = threading.Lock()


def _run_inference(data):
    """Call the module-level ``handler`` while holding the inference lock."""
    with _inference_lock:
        return handler(data)


@app.post("/")
@app.post("/v1/chat/completions")
async def handle(request: Request):
    """Parse the JSON request body and return the handler's result.

    The handler call is blocking, so it is run in a worker thread; running
    it inline would stall the event loop (and therefore ``/health``) for the
    whole duration of a generation.
    """
    data = await request.json()
    return await asyncio.to_thread(_run_inference, data)
@app.get("/health")
async def health():
    """Liveness probe: returns a constant OK payload.

    It does not inspect the model or the handler.
    """
    status_payload = {"status": "ok"}
    return status_payload