Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from sklearn.datasets import load_iris
|
| 3 |
+
from sklearn.tree import DecisionTreeClassifier, plot_tree
|
| 4 |
+
from sklearn.model_selection import train_test_split
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
# Load the Iris dataset
|
| 9 |
+
iris = load_iris()
|
| 10 |
+
X = iris.data
|
| 11 |
+
y = iris.target
|
| 12 |
+
feature_names = iris.feature_names
|
| 13 |
+
target_names = iris.target_names
|
| 14 |
+
|
| 15 |
+
# Train a Decision Tree model
|
| 16 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
|
| 17 |
+
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
|
| 18 |
+
clf.fit(X_train, y_train)
|
| 19 |
+
|
| 20 |
+
# Streamlit interface
|
| 21 |
+
st.title("🌸 Iris Flower Predictor with Decision Tree")
|
| 22 |
+
st.write("This app uses a Decision Tree Classifier to predict the type of Iris flower.")
|
| 23 |
+
|
| 24 |
+
# Sidebar for user input
|
| 25 |
+
st.sidebar.header("Input Features")
|
| 26 |
+
sepal_length = st.sidebar.slider('Sepal length (cm)', 4.0, 8.0, 5.1)
|
| 27 |
+
sepal_width = st.sidebar.slider('Sepal width (cm)', 2.0, 4.5, 3.5)
|
| 28 |
+
petal_length = st.sidebar.slider('Petal length (cm)', 1.0, 7.0, 1.4)
|
| 29 |
+
petal_width = st.sidebar.slider('Petal width (cm)', 0.1, 2.5, 0.2)
|
| 30 |
+
|
| 31 |
+
# Make prediction
|
| 32 |
+
input_data = [[sepal_length, sepal_width, petal_length, petal_width]]
|
| 33 |
+
prediction = clf.predict(input_data)[0]
|
| 34 |
+
predicted_class = target_names[prediction]
|
| 35 |
+
|
| 36 |
+
st.subheader("🌼 Predicted Iris Species")
|
| 37 |
+
st.success(f"The model predicts: **{predicted_class}**")
|
| 38 |
+
|
| 39 |
+
# Show decision tree
|
| 40 |
+
st.subheader("🧠 Decision Tree Visualization")
|
| 41 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
| 42 |
+
plot_tree(clf, feature_names=feature_names, class_names=target_names, filled=True, rounded=True)
|
| 43 |
+
st.pyplot(fig)
|
| 44 |
+
|
| 45 |
+
# Show model accuracy
|
| 46 |
+
accuracy = clf.score(X_test, y_test)
|
| 47 |
+
st.subheader("📈 Model Accuracy")
|
| 48 |
+
st.write(f"The model accuracy on the test set is **{accuracy:.2f}**")
|