|
|
import io
from collections import Counter

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image

from models import ModelA, ModelB, ModelC, transform_small, transform_large
|
|
|
|
|
app = FastAPI()

# All inference runs on the CPU; checkpoints are remapped onto it at load time.
device = torch.device('cpu')


def _load_model(model_cls, checkpoint_path):
    """Build *model_cls*, load its weights from *checkpoint_path*, and return it in eval mode.

    The checkpoint is loaded with ``weights_only=True`` so that only tensor data
    (not arbitrary pickled objects) is deserialized, and it is remapped onto
    ``device``. The model itself is also moved to ``device`` so weights and the
    inputs built in the request handler live on the same device.
    """
    model = model_cls()
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    # eval() switches layers such as dropout / batch-norm to inference behaviour.
    model.eval()
    return model


modelA = _load_model(ModelA, 'modelA.pth')
modelB = _load_model(ModelB, 'modelB.pth')
modelC = _load_model(ModelC, 'modelC.pth')
|
|
|
def _ensemble_votes(img):
    """Return the predicted class index from each ensemble member for *img*.

    ModelA and ModelB receive the ``transform_small`` tensor; ModelC receives
    the ``transform_large`` tensor. Each tensor is given a leading batch
    dimension of size 1, so every model contributes exactly one integer vote.

    This is synchronous, CPU-bound work; ``predict`` runs it in a worker
    thread so it does not block the event loop.
    """
    t_small = transform_small(img).unsqueeze(0).to(device)
    t_large = transform_large(img).unsqueeze(0).to(device)
    members = ((modelA, t_small), (modelB, t_small), (modelC, t_large))

    # no_grad is thread-local, so it must be entered in the thread that runs the models.
    with torch.no_grad():
        # argmax over dim 1 -- assumes each output is (batch, num_classes) scores; TODO confirm.
        return [int(model(inp).argmax(dim=1).item()) for model, inp in members]


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image as "Real" or "Fake" by majority vote of three models.

    Each model casts one vote; the most common label wins. Label ``1`` maps to
    ``"Real"`` and every other label maps to ``"Fake"``.

    Args:
        file: The uploaded image (any format Pillow can decode).

    Returns:
        ``{"prediction": "Real" | "Fake", "confidence": "<pct>%"}``, where the
        confidence is the share of models that voted for the winning label,
        formatted with one decimal place.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    data = await file.read()

    try:
        # Context manager closes the source image once the RGB copy exists.
        with Image.open(io.BytesIO(data)) as src:
            img = src.convert('RGB')
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unknown/empty data) and truncated-file errors
        # are OSError subclasses; report them as a client error, not a 500.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.") from err

    # Model inference is CPU-bound; keep it off the event loop.
    votes = await run_in_threadpool(_ensemble_votes, img)

    # most_common breaks ties by first-seen order, so ModelA's vote wins a
    # three-way split (only possible if the models emit more than two classes).
    final_label, count = Counter(votes).most_common(1)[0]
    confidence = count / len(votes)

    return {
        "prediction": "Real" if final_label == 1 else "Fake",
        "confidence": f"{confidence*100:.1f}%"
    }
|
|
|