# ANN_Churn / main.py
# yashalchemist: "Add application file" (531bd39)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, conint, confloat
from enum import Enum
import numpy as np
import pandas as pd
import pickle as pkl
import tensorflow as tf
from tensorflow.keras.models import load_model
# Initialize FastAPI app
app = FastAPI()

# Load the trained model once, at import time, so the endpoint below can
# reuse the same in-memory instance instead of re-reading the file.
model = load_model("model.h5")

# Load the pre-trained scalers and encoders.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# These artifacts must only ever come from a trusted source (e.g. shipped
# alongside this application), never from user-controlled locations.
with open("scaler.pkl", "rb") as f:
    scaler = pkl.load(f)
with open("label_encoder_gender.pkl", "rb") as f:
    gen_encoder = pkl.load(f)
with open("onehot_encoder_geography.pkl", "rb") as f:
    geo_encoder = pkl.load(f)
# Enums for Gender and Geography
class GenderEnum(str, Enum):
    """Closed set of gender values accepted in a request.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    Male = "Male"
    Female = "Female"
class GeographyEnum(str, Enum):
    """Closed set of geography values accepted in a request.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    France = "France"
    Germany = "Germany"
    Spain = "Spain"
# Pydantic model for request validation
class CustomerData(BaseModel):
    """Request body for the prediction endpoint.

    Every field is range- or enum-constrained, so out-of-range values are
    rejected by validation before the endpoint body runs.
    """

    # Integer in [350, 850].
    CreditScore: conint(ge=350, le=850)
    # One of the GenderEnum members.
    Gender: GenderEnum
    # Integer in [18, 92].
    Age: conint(ge=18, le=92)
    # Integer in [0, 10].
    Tenure: conint(ge=0, le=10)
    # Non-negative float.
    Balance: confloat(ge=0)
    # Integer in [1, 4].
    NumOfProducts: conint(ge=1, le=4)
    # Integer flag restricted to 0 or 1.
    HasCrCard: conint(ge=0, le=1)
    # Integer flag restricted to 0 or 1.
    IsActiveMember: conint(ge=0, le=1)
    # Non-negative float.
    EstimatedSalary: confloat(ge=0)
    # One of the GeographyEnum members.
    Geography: GeographyEnum
def _build_feature_frame(data: CustomerData) -> pd.DataFrame:
    """Assemble the single-row feature frame that is passed to the scaler."""
    # Encode gender with the fitted label encoder (one value in, one out).
    gender_encoded = gen_encoder.transform([data.Gender.value])[0]

    # One-hot encode geography.
    geo_encoded = geo_encoder.transform([[data.Geography.value]])
    # NOTE(review): geo_encoder is presumably a scikit-learn OneHotEncoder,
    # whose transform returns a SciPy sparse matrix unless it was configured
    # for dense output. pd.DataFrame cannot be built from a sparse matrix
    # with explicit columns, so densify when needed.
    if hasattr(geo_encoded, "toarray"):
        geo_encoded = geo_encoded.toarray()

    # Name the columns "Geography_<category>" to match the training data;
    # for France/Germany/Spain this is exactly the naming used previously.
    geo_columns = [f"Geography_{category}" for category in geo_encoder.categories_[0]]
    geo_encoded_df = pd.DataFrame(geo_encoded, columns=geo_columns)

    # Column order here is the order handed to the scaler; keep it stable.
    input_data = {
        "CreditScore": data.CreditScore,
        "Gender": gender_encoded,
        "Age": data.Age,
        "Tenure": data.Tenure,
        "Balance": data.Balance,
        "NumOfProducts": data.NumOfProducts,
        "HasCrCard": data.HasCrCard,
        "IsActiveMember": data.IsActiveMember,
        "EstimatedSalary": data.EstimatedSalary,
    }
    input_df = pd.DataFrame([input_data])

    # Append one-hot encoded geography to the right of the base features.
    return pd.concat([input_df, geo_encoded_df], axis=1)


# API Endpoint for prediction
@app.post("/predict/")
def predict_churn(data: CustomerData):
    """Predict whether a customer is likely to churn.

    Args:
        data: Validated customer attributes from the request body.

    Returns:
        A dict with ``"prediction"`` (a human-readable verdict) and
        ``"probability"`` (the model's raw output as a Python float).
        The verdict is "likely to churn" when the probability is
        strictly greater than 0.5.
    """
    input_df = _build_feature_frame(data)

    # Scale the input data
    input_scaled = scaler.transform(input_df)

    # Make prediction; verbose=0 suppresses the per-call progress output.
    prediction = model.predict(input_scaled, verbose=0)

    # Only the first output of the first (and only) row is used; convert it
    # once so the comparison and the response use the same value.
    probability = float(prediction[0][0])

    # Return result
    if probability > 0.5:
        result = "The customer is likely to churn"
    else:
        result = "The customer is not likely to churn"
    return {"prediction": result, "probability": probability}