|
|
|
|
| import torch
|
| from io import BytesIO
|
| from PIL import Image
|
| from torchvision import transforms
|
| from TumorModel import TumorClassification
|
|
|
|
|
# Preprocessing applied to every input image before it reaches the model:
# collapse to a single channel, resize to 224x224, convert to a float tensor
# scaled to [0, 1] (ToTensor), then map to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): these steps presumably mirror the training pipeline — keep them in sync.
_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
|
|
|
|
|
# Build the classifier once at import time and load its trained weights onto the CPU.
_model = TumorClassification()
# The checkpoint is consumed as a state_dict, so only tensors and plain containers
# need to be unpickled. weights_only=True enforces that: it avoids executing
# arbitrary pickled code from the file and silences the FutureWarning that
# torch 2.4–2.5 emits when the argument is omitted.
# NOTE: "BTD_model.pth" is resolved relative to the current working directory,
# not to this module's location.
_model.load_state_dict(
    torch.load("BTD_model.pth", map_location="cpu", weights_only=True)
)
# Put layers with train/eval-dependent behaviour (e.g. dropout, batch-norm, if the
# model has any) into inference mode.
_model.eval()
|
|
|
# Class names indexed by the model's output position (argmax over dim 1).
# NOTE(review): this order must match the label order used during training;
# it is alphabetical, which presumably mirrors ImageFolder's class sorting — confirm.
_LABELS = ("glioma", "meningioma", "notumor", "pituitary")


def inference(image_bytes):
    """Classify an encoded image and return the predicted tumor label.

    Hugging Face will pass the raw image bytes here.

    Args:
        image_bytes: Encoded image data in any format PIL can decode.

    Returns:
        ``{"label": <one of glioma, meningioma, notumor, pituitary>}``.

    Raises:
        PIL.UnidentifiedImageError: if ``image_bytes`` cannot be decoded as an image
            (propagated from ``Image.open``).
    """
    # Decode and normalise to RGB so the Grayscale step always sees a consistent
    # mode (e.g. palette or RGBA inputs). convert() returns an independent image,
    # so the source can be closed as soon as the block exits.
    with Image.open(BytesIO(image_bytes)) as src:
        img = src.convert("RGB")

    # Add a leading batch dimension: the model is run on a batch of one.
    x = _transform(img).unsqueeze(0)

    # Gradients are not needed for prediction.
    with torch.no_grad():
        idx = torch.argmax(_model(x), dim=1).item()

    return {"label": _LABELS[idx]}
|
|
|