| import torch |
| from PIL import Image |
| import io |
| import base64 |
| from transformers import GlmOcrProcessor, GlmOcrForConditionalGeneration |
| from fastapi import FastAPI, Request |
|
|
class EndpointHandler():
    """Wraps a GLM-OCR model for receipt extraction on base64-encoded images."""

    def __init__(self, path=""):
        """Load the processor and model from *path*.

        Args:
            path: local directory or hub id of the pretrained checkpoint.
        """
        # The processor handles both image preprocessing and text tokenization.
        self.processor = GlmOcrProcessor.from_pretrained(path)
        self.model = GlmOcrForConditionalGeneration.from_pretrained(
            path,
            device_map="auto",          # place weights on available accelerator(s)
            torch_dtype=torch.bfloat16
        )
        self.model.eval()  # disable dropout etc.; inference only

    def __call__(self, data):
        """Run OCR extraction on a single base64-encoded image.

        Args:
            data: request payload; the base64 image string is read from
                data["inputs"] (falling back to the payload itself).
                NOTE: ``pop`` mutates the caller's dict.

        Returns:
            A one-element list: [{"generated_text": <decoded model output>}].
        """
        inputs_data = data.pop("inputs", data)
        image_bytes = base64.b64decode(inputs_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        prompt = "Extract receipt items into JSON: [{date, vendor, description, qty, price, total}]"

        inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024)

        # FIX: generate() returns prompt + completion; slice off the prompt
        # tokens so only newly generated text is decoded (otherwise the
        # instruction prompt is echoed back in the response).
        new_tokens = generated_ids[:, inputs["input_ids"].shape[1]:]
        result = self.processor.batch_decode(new_tokens, skip_special_tokens=True)[0]

        return [{"generated_text": result}]


app = FastAPI()
# FIX: previously `handler = EndpointHandler(...)` appeared *above* the class
# definition, raising NameError at import time. Instantiate it after the
# class exists; model weights load once at module import.
handler = EndpointHandler("/repository")
|
|
|
|
@app.post("/")
@app.post("/v1/chat/completions")
async def handle(request: Request):
    """Decode the JSON request body and delegate inference to the module-level handler."""
    payload = await request.json()
    return handler(payload)
|
|
@app.get("/health")
async def health():
    """Liveness probe: always reports the service as up."""
    return dict(status="ok")