| | import base64 |
| | from fastapi import FastAPI, File, UploadFile |
| | from pydantic import BaseModel |
| | import numpy as np |
| | from PIL import Image |
| | import gradio as gr |
| | import tensorflow as tf |
| |
|
| | |
# Path of the saved Keras model, resolved relative to the current working directory.
MODEL_FILE = 'modelCNN.h5'

# Loaded once at module import; predict() reuses this single instance for every call.
modelCNN = tf.keras.models.load_model(MODEL_FILE)
| |
|
def _first_layer(sketch):
    """Return the first drawing layer of a sketchpad value as an ndarray, or None.

    Only a dict carrying a non-empty ``'layers'`` sequence yields a layer
    (NOTE(review): this matches the Gradio 4 sketchpad value shape -- confirm
    against the installed Gradio version). Anything else, including ``None``,
    yields ``None``.
    """
    if not isinstance(sketch, dict):
        return None
    layers = sketch.get('layers')
    if layers is None or len(layers) == 0:
        return None
    return np.asarray(layers[0])


def _to_model_input(layer):
    """Convert a drawing layer into a (1, 28, 28, 1) float32 array scaled to [0, 1].

    Accepts uint8-style data (0-255) or float data already in [0, 1]. For RGBA
    layers the alpha channel is used as the stroke intensity, so drawn pixels
    come out bright and untouched pixels come out 0 whatever the brush colour.
    """
    if np.issubdtype(layer.dtype, np.floating) and layer.size and layer.max() <= 1.0:
        # Normalised float input: bring it into the 0-255 range PIL expects.
        # Scaling uint8 data by 255 instead would silently wrap modulo 256.
        layer = layer * 255
    layer = np.clip(layer, 0, 255).astype(np.uint8)

    if layer.ndim == 3 and layer.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the unit axis.
        layer = layer[..., 0]

    if layer.ndim == 3 and layer.shape[-1] == 4:
        # RGBA: convert('L') would ignore alpha, so a dark brush on a transparent
        # canvas (presumably stored as RGB = 0 -- confirm) would become all black.
        image = Image.fromarray(np.ascontiguousarray(layer[..., 3]))
    else:
        image = Image.fromarray(layer).convert('L')

    # NOTE(review): assumes the model expects 28x28 single-channel input with
    # bright strokes on a dark background (MNIST-style) -- confirm with training code.
    image = image.resize((28, 28))
    return np.asarray(image, dtype=np.float32).reshape(1, 28, 28, 1) / 255.0


def predict(sketch):
    """Classify a hand-drawn sketch and return the predicted class as a string.

    Args:
        sketch: The sketchpad value passed in by Gradio. Only dicts with a
            non-empty ``'layers'`` sequence are classified; the first layer is used.

    Returns:
        The arg-max class index from ``modelCNN`` as a string,
        ``"No sketch data found"`` when there is no layer to classify, or
        ``"An error occurred"`` if preprocessing or inference raises (the
        exception is printed to stdout).
    """
    try:
        layer = _first_layer(sketch)
        if layer is None:
            return "No sketch data found"

        img_array = _to_model_input(layer)
        # verbose=0 stops Keras from printing a progress bar on every request.
        prediction = modelCNN.predict(img_array, verbose=0)[0]
        return str(int(np.argmax(prediction)))
    except Exception as e:
        # UI-callback boundary: report the failure and keep the app running.
        print("Error:", e)
        return "An error occurred"
| |
|
| | |
# Web UI: a sketchpad input is passed to predict() and the result is shown as a label.
demo = gr.Interface(
    fn=predict,
    inputs="sketchpad",
    outputs="label",
)

# Launching at module level means importing this file also starts the server.
demo.launch()
| |
|