# TimeSeriesForecasting / src/streamlit_app.py
# EnYa32's picture
# Update src/streamlit_app.py
# 461d207 verified
import json
import joblib
import numpy as np
import pandas as pd
import streamlit as st
from pathlib import Path
# -------------------------
# Page config
# -------------------------
# set_page_config is issued before any other Streamlit call in this script.
st.set_page_config(page_title='Sales Forecast (LightGBM)', page_icon='📈', layout='centered')

st.title('📈 Sales Forecast (LightGBM)')
st.write('Predict **num_sold** using a trained LightGBM model + saved encoders and preprocessing.')

# All artifacts are resolved relative to the directory containing this script.
BASE_DIR = Path(__file__).resolve().parent
MODEL_PATH = BASE_DIR.joinpath('model_lgbm.pkl')
FEATURES_PATH = BASE_DIR.joinpath('feature_names.pkl')
ENCODERS_PATH = BASE_DIR.joinpath('encoders.pkl')
FILLMAP_PATH = BASE_DIR.joinpath('fill_map.pkl')
META_PATH = BASE_DIR.joinpath('meta.json')  # optional; see load_assets
@st.cache_resource
def load_assets():
    """Load the model and preprocessing artifacts stored next to this script.

    The four pickled artifacts are required; ``meta.json`` is optional.

    Returns:
        A tuple ``(model, features, encoders, fill_map, meta)`` where
        ``meta`` is the parsed content of ``meta.json``, or ``None`` when
        that file does not exist.

    Raises:
        FileNotFoundError: If any of the required artifact files is missing.
    """
    required_paths = (MODEL_PATH, FEATURES_PATH, ENCODERS_PATH, FILLMAP_PATH)

    # Verify every required file up front so nothing is loaded if one is absent.
    for path in required_paths:
        if not path.exists():
            # Name the actual script file instead of a hard-coded 'app.py'.
            raise FileNotFoundError(
                f'Missing {path.name} (put it next to {Path(__file__).name}).'
            )

    # NOTE: joblib.load unpickles its input; only ship artifacts you trust.
    model = joblib.load(MODEL_PATH)
    features = joblib.load(FEATURES_PATH)
    encoders = joblib.load(ENCODERS_PATH)
    fill_map = joblib.load(FILLMAP_PATH)

    meta = None
    if META_PATH.exists():
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(META_PATH, 'r', encoding='utf-8') as f:
            meta = json.load(f)

    return model, features, encoders, fill_map, meta
# Unpack the cached artifacts into the module-level names used below.
model, FEATURES, encoders, fill_map, meta = load_assets()

# Show the metadata when it is present and non-empty; otherwise a short notice.
with st.expander('ℹ️ Model info'):
    st.write(meta if meta else 'No meta.json found.')
# -------------------------
# Helpers
# -------------------------
def make_date_features(date_value: pd.Timestamp) -> dict:
    """Derive integer calendar features from a single timestamp.

    Args:
        date_value: The date to decompose.

    Returns:
        A dict with the keys ``year``, ``month``, ``week`` (ISO week number),
        ``dayofweek`` (Monday=0), ``is_weekend`` (1 when ``dayofweek`` is 5
        or 6, else 0) and ``dayofyear``; every value is a plain ``int``.
    """
    weekday = int(date_value.dayofweek)  # Monday=0
    iso_week = date_value.isocalendar().week
    return {
        'year': int(date_value.year),
        'month': int(date_value.month),
        'week': int(iso_week),
        'dayofweek': weekday,
        'is_weekend': int(weekday >= 5),
        'dayofyear': int(date_value.dayofyear),
    }
def safe_encode(col_name: str, value: str) -> int:
    """Encode ``value`` with the saved encoder for ``col_name``.

    Args:
        col_name: Key used to look up the encoder in ``encoders``.
        value: Raw label; it is compared in its ``str`` form.

    Returns:
        The encoder's integer code for the label. Returns 0 when no encoder
        is registered for ``col_name``. A label the encoder does not know is
        replaced by the first entry of ``classes_`` before encoding.
    """
    encoder = encoders.get(col_name)
    if encoder is None:
        return 0

    known_labels = set(encoder.classes_.astype(str))
    label = str(value)
    if label not in known_labels:
        # Unseen label: substitute the first known class.
        label = str(encoder.classes_[0])
    return int(encoder.transform([label])[0])
# -------------------------
# UI Inputs
# -------------------------
st.subheader('🧾 Input')

date_in = st.date_input('Date', value=pd.to_datetime('2019-01-01'))
country_in = st.text_input('Country', value='Finland')
store_in = st.text_input('Store', value='KaggleMart')
product_in = st.text_input('Product', value='Kaggle Mug')

st.markdown('---')
st.subheader('⏳ Lag features')

use_manual_lags = st.checkbox('Enter lag values manually (recommended if you know them)', value=False)

# Default for each lag comes from fill_map, or 0.0 when the key is absent.
lag_defaults = {
    lag_name: float(fill_map.get(lag_name, 0.0))
    for lag_name in ('lag_364', 'lag_365', 'lag_371')
}
default_lag_364, default_lag_365, default_lag_371 = lag_defaults.values()

if not use_manual_lags:
    # Defaults are displayed and used as-is.
    lag_364, lag_365, lag_371 = default_lag_364, default_lag_365, default_lag_371
    st.write('Using default lag values (from training medians):')
    st.write(lag_defaults)
else:
    # One numeric input per lag, pre-filled with its default.
    lag_364, lag_365, lag_371 = (
        st.number_input(lag_name, value=lag_default)
        for lag_name, lag_default in lag_defaults.items()
    )
# -------------------------
# Predict
# -------------------------
if st.button('Predict'):
    # Single-row feature dict: calendar features, then lags, then the
    # encoded categorical columns.
    row = {
        **make_date_features(pd.to_datetime(date_in)),
        'lag_364': float(lag_364),
        'lag_365': float(lag_365),
        'lag_371': float(lag_371),
        'country': safe_encode('country', country_in),
        'store': safe_encode('store', store_in),
        'product': safe_encode('product', product_in),
    }
    X = pd.DataFrame([row])

    # Any feature listed in FEATURES but absent from the row gets its
    # fill_map value (0.0 if it has none); then columns are put in FEATURES order.
    for feature in FEATURES:
        if feature not in X.columns:
            X[feature] = fill_map.get(feature, 0.0)
    X = X[FEATURES].copy()

    # Replace any remaining NaNs column by column with the same fallback.
    for feature in X.columns:
        column = X[feature]
        if not column.isna().any():
            continue
        X[feature] = column.fillna(fill_map.get(feature, 0.0))

    pred = model.predict(X)[0]
    st.success(f'✅ Predicted num_sold: **{pred:.2f}**')
    st.caption('Note: Lag features heavily influence the forecast. If you can compute real lags, the prediction will be more accurate.')

    with st.expander('Show model input vector'):
        st.dataframe(X)