Spaces:
Sleeping
Sleeping
| import json | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| from pathlib import Path | |
# -------------------------
# Page setup
# -------------------------
# Configure the browser tab (title, icon) and use the centered layout.
st.set_page_config(page_title='Sales Forecast (LightGBM)', page_icon='📈', layout='centered')

st.title('📈 Sales Forecast (LightGBM)')
st.write('Predict **num_sold** using a trained LightGBM model + saved encoders and preprocessing.')

# All artifacts are resolved relative to the directory containing this script.
BASE_DIR = Path(__file__).resolve().parent
MODEL_PATH = BASE_DIR.joinpath('model_lgbm.pkl')
FEATURES_PATH = BASE_DIR.joinpath('feature_names.pkl')
ENCODERS_PATH = BASE_DIR.joinpath('encoders.pkl')
FILLMAP_PATH = BASE_DIR.joinpath('fill_map.pkl')
META_PATH = BASE_DIR.joinpath('meta.json')
def load_assets():
    """Load the trained model and its preprocessing artifacts from disk.

    The four pickled artifacts (model, feature names, encoders, fill map)
    are required; ``meta.json`` is optional.

    Returns:
        A tuple ``(model, features, encoders, fill_map, meta)``. ``meta`` is
        the parsed content of ``meta.json``, or ``None`` when that file does
        not exist.

    Raises:
        FileNotFoundError: If a required artifact is missing. Files are
            checked in the order model, features, encoders, fill map, and
            the first missing one is reported.
    """
    required_paths = (MODEL_PATH, FEATURES_PATH, ENCODERS_PATH, FILLMAP_PATH)

    # Validate everything up front so nothing is loaded when the bundle is incomplete.
    for path in required_paths:
        if not path.exists():
            raise FileNotFoundError(f'Missing {path.name} (put it next to app.py).')

    # NOTE(review): joblib.load unpickles its input; these files must come
    # from a trusted source (they are expected to ship next to this script).
    model, features, encoders, fill_map = (joblib.load(path) for path in required_paths)

    meta = None
    if META_PATH.exists():
        # JSON is UTF-8 by specification; do not rely on the platform locale.
        with open(META_PATH, encoding='utf-8') as f:
            meta = json.load(f)

    return model, features, encoders, fill_map, meta
# Load every artifact once per script run; the names below are used by the
# helpers and UI code that follow.
model, FEATURES, encoders, fill_map, meta = load_assets()

# Collapsible panel showing the optional metadata, or a placeholder message
# when meta.json was absent (or parsed to an empty value).
with st.expander('ℹ️ Model info'):
    st.write(meta if meta else 'No meta.json found.')
| # ------------------------- | |
| # Helpers | |
| # ------------------------- | |
def make_date_features(date_value: pd.Timestamp) -> dict:
    """Derive the calendar features for a single date.

    Args:
        date_value: The date to describe, as a pandas Timestamp.

    Returns:
        A dict with integer values for ``year``, ``month``, ``week`` (ISO
        week number), ``dayofweek`` (Monday=0), ``is_weekend`` (1 when
        ``dayofweek`` is 5 or 6, else 0) and ``dayofyear``.
    """
    iso_week = date_value.isocalendar().week
    weekday = int(date_value.dayofweek)  # Monday=0 ... Sunday=6

    return {
        'year': int(date_value.year),
        'month': int(date_value.month),
        'week': int(iso_week),
        'dayofweek': weekday,
        'is_weekend': 1 if weekday >= 5 else 0,
        'dayofyear': int(date_value.dayofyear),
    }
def safe_encode(col_name: str, value: str) -> int:
    """Map a categorical value to its integer code without raising on unseen labels.

    The code is the position of ``value`` (compared as a string) within the
    fitted encoder's ``classes_``, which is the same integer the encoder's
    ``transform`` yields for a known label.

    Args:
        col_name: Key of the encoder in the module-level ``encoders`` mapping.
        value: Raw label; it is converted with ``str()`` before matching.

    Returns:
        The label's index in ``classes_``. Returns 0 when there is no encoder
        for ``col_name`` or when the label is unknown. Note that 0 is simply
        the first entry of ``classes_`` -- not the most frequent label.
    """
    le = encoders.get(col_name)
    if le is None:
        return 0

    # Compare on string form and take the position directly. Calling
    # le.transform() with the stringified value would raise for encoders
    # fitted on non-string labels even though the membership test passed.
    known_labels = [str(label) for label in le.classes_]
    try:
        return known_labels.index(str(value))
    except ValueError:
        # Unseen label: fall back to the first known class.
        return 0
# -------------------------
# Input widgets
# -------------------------
st.subheader('🧾 Input')
date_in = st.date_input('Date', value=pd.to_datetime('2019-01-01'))
country_in = st.text_input('Country', value='Finland')
store_in = st.text_input('Store', value='KaggleMart')
product_in = st.text_input('Product', value='Kaggle Mug')

st.markdown('---')
st.subheader('⏳ Lag features')
use_manual_lags = st.checkbox('Enter lag values manually (recommended if you know them)', value=False)

# Defaults come from fill_map; a lag missing from it defaults to 0.0.
default_lag_364, default_lag_365, default_lag_371 = (
    float(fill_map.get(lag_name, 0.0))
    for lag_name in ('lag_364', 'lag_365', 'lag_371')
)

if not use_manual_lags:
    # Show the defaults that will be used and pass them straight through.
    st.write('Using default lag values (from training medians):')
    st.write(dict(
        lag_364=default_lag_364,
        lag_365=default_lag_365,
        lag_371=default_lag_371,
    ))
    lag_364 = default_lag_364
    lag_365 = default_lag_365
    lag_371 = default_lag_371
else:
    # Let the user override each lag, pre-filled with its default.
    lag_364 = st.number_input('lag_364', value=default_lag_364)
    lag_365 = st.number_input('lag_365', value=default_lag_365)
    lag_371 = st.number_input('lag_371', value=default_lag_371)
# -------------------------
# Prediction
# -------------------------
if st.button('Predict'):
    # Assemble the single input row: calendar features first, then the lag
    # values, then the integer-encoded categorical columns.
    row = {
        **make_date_features(pd.to_datetime(date_in)),
        'lag_364': float(lag_364),
        'lag_365': float(lag_365),
        'lag_371': float(lag_371),
        'country': safe_encode('country', country_in),
        'store': safe_encode('store', store_in),
        'product': safe_encode('product', product_in),
    }
    X = pd.DataFrame([row])

    # Any model feature the row does not provide gets its fill_map default
    # (0.0 when fill_map has no entry for it).
    absent = [name for name in FEATURES if name not in X.columns]
    for name in absent:
        X[name] = fill_map.get(name, 0.0)

    # Keep exactly the model's features, in the model's column order.
    X = X[FEATURES].copy()

    # Replace any remaining NaN with the same per-column default.
    for name in X.columns:
        column = X[name]
        if column.isna().any():
            X[name] = column.fillna(fill_map.get(name, 0.0))

    pred = model.predict(X)[0]
    st.success(f'✅ Predicted num_sold: **{pred:.2f}**')
    st.caption('Note: Lag features heavily influence the forecast. If you can compute real lags, the prediction will be more accurate.')

    with st.expander('Show model input vector'):
        st.dataframe(X)