| import gradio as gr |
| import pandas as pd |
| import joblib |
| import os |
| from sklearn.ensemble import RandomForestClassifier |
| from sklearn.model_selection import train_test_split |
|
|
# File the trained model is dumped to and reloaded from on later runs.
MODEL_PATH = "rf_model.pkl"
# UCI Wine Quality (white wine) CSV; parsed with sep=';' by train_model().
DATA_URL = (
    "https://archive.ics.uci.edu/ml/machine-learning-databases/"
    "wine-quality/winequality-white.csv"
)
|
|
| |
| |
| |
|
|
def train_model():
    """Download the UCI white-wine data, fit a Random Forest, and cache it.

    The data is split 80/20 with a fixed seed; the model is fitted on the
    80% part and its accuracy on the held-out 20% is printed as a sanity
    check. The fitted model is then written to ``MODEL_PATH`` with joblib.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    print("Downloading white wine dataset...")
    # The source CSV uses ';' as its field separator.
    df = pd.read_csv(DATA_URL, sep=';')

    # Kept local rather than reusing the module-level ``feature_names``:
    # this function is invoked during module import, before that list exists.
    feature_names = [
        'fixed acidity', 'volatile acidity', 'citric acid', 'residual sugar',
        'chlorides', 'free sulfur dioxide', 'total sulfur dioxide', 'density',
        'pH', 'sulphates', 'alcohol',
    ]
    X = df[feature_names]
    y = df['quality']

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    print("Training Random Forest model...")
    model = RandomForestClassifier(
        n_estimators=300,
        max_depth=12,
        random_state=42,
    )
    model.fit(X_train, y_train)

    # Score on the held-out rows so a bad download or a degraded model is
    # visible in the startup logs instead of going unnoticed.
    print(f"Held-out accuracy: {model.score(X_test, y_test):.3f}")

    joblib.dump(model, MODEL_PATH)
    print(f"Model saved as {MODEL_PATH}")
    return model
|
|
|
|
| |
# Reuse the model cached on disk by a previous run. Train (and cache) a new
# one when no cache exists, or when the cached file cannot be loaded (e.g. it
# is truncated or was pickled by an incompatible library version).
if os.path.exists(MODEL_PATH):
    print("Loading existing model...")
    try:
        # NOTE(review): joblib.load unpickles arbitrary objects. This assumes
        # MODEL_PATH is only ever written by train_model(); never point it at
        # an untrusted file.
        model = joblib.load(MODEL_PATH)
    except Exception as err:  # top-level startup boundary: log and recover
        print(f"Could not load {MODEL_PATH} ({err!r}); retraining...")
        model = train_model()
else:
    model = train_model()
|
|
|
|
| |
| |
| |
|
|
# Model input columns, in the order used for the UI fields and for the
# prediction DataFrame. Must stay identical to the list in train_model().
feature_names = [
    'fixed acidity',
    'volatile acidity',
    'citric acid',
    'residual sugar',
    'chlorides',
    'free sulfur dioxide',
    'total sulfur dioxide',
    'density',
    'pH',
    'sulphates',
    'alcohol',
]
|
|
def predict_quality(*inputs):
    """Predict the quality class for one wine from its feature values.

    Args:
        *inputs: One value per entry of ``feature_names``, in the same order.
            Gradio passes the ``gr.Number`` component values positionally.

    Returns:
        A display string containing the predicted quality class.

    Raises:
        gr.Error: If any field was left empty (Gradio delivers ``None``), so
            the user gets a clear message instead of the model receiving a
            missing value.
    """
    missing = [name for name, value in zip(feature_names, inputs) if value is None]
    if missing:
        raise gr.Error(f"Please enter a value for: {', '.join(missing)}")

    # One-row frame with the training column names, so the model sees the
    # features with the same names and order it was fitted on.
    row = pd.DataFrame([inputs], columns=feature_names)
    prediction = model.predict(row)[0]
    return f"Predicted Wine Quality: {prediction}"
|
|
|
|
| |
| |
| |
|
|
# One numeric input per feature, in the positional order predict_quality expects.
inputs_ui = [gr.Number(label=name) for name in feature_names]

demo = gr.Interface(
    fn=predict_quality,
    inputs=inputs_ui,
    outputs=gr.Textbox(label="Prediction"),
    title="🍾 White Wine Quality Predictor (Trains on HF Space)",
    description="Random Forest model trained on the UCI White Wine Quality dataset.",
)

# Start the web server only when executed as a script, so importing this
# module (e.g. from a test) does not block inside launch().
if __name__ == "__main__":
    demo.launch()
|
|