# api/src/forecasting/chronos_forecaster.py
# Author: Eli Safra
# Commit 938949f — Deploy SolarWine API (FastAPI + Docker, port 7860)
"""
ChronosForecaster: Day-ahead photosynthesis (A) forecasting using Amazon
Chronos-2 foundation model with native covariate support and optional
LoRA fine-tuning.
Improvement history:
v1: Broken — daytime-only rows with hidden gaps → MAE ~8.5
v2: Regular 15-min grid + predict_df + daytime eval → MAE ~1.75 (20w)
v3: + On-site sensor covariates (PAR, VPD, T_leaf, CO2)
+ 14-day context (captures ~2 weeks of diurnal pattern)
+ LoRA fine-tuning (1000 steps, lr=1e-4)
+ Configurable covariate modes for ablation
→ MAE 1.37 (May), 3.0-3.4 (Jun-Sep), overall beats ML baseline (2.7)
v4: Revisited input features: added engineered time (hour_sin/cos, doy_sin/cos) and
stress_risk_ims (VPD from IMS T+RH) in load_data; tried extended IMS (tdmax/tdmin).
Ablation on current data: best zero-shot = sensor (MAE ~3.86) or all (MAE ~3.91, R² 0.52).
Time/stress as covariates slightly hurt; kept 4-col IMS + sensor for "all".
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from config.settings import (
PROCESSED_DIR, IMS_CACHE_DIR, OUTPUTS_DIR, GROWING_SEASON_MONTHS,
)
from src.time_features import add_cyclical_time_features
# ---------------------------------------------------------------------------
# Covariate definitions
# ---------------------------------------------------------------------------

# IMS station 43 weather columns (available as day-ahead forecasts in
# production). tdmax_c / tdmin_c also exist in the data, but ablation showed
# this 4-column IMS set works best for this dataset.
IMS_COVARIATE_COLS = [
    "ghi_w_m2",
    "air_temperature_c",
    "rh_percent",
    "wind_speed_ms",
]

# On-site Seymour sensor columns (past-only: no forecasts exist for these).
SENSOR_COVARIATE_COLS = [
    "PAR_site",
    "VPD_site",
    "T_leaf_site",
    "CO2_site",
]

# Engineered cyclical time features — deterministic from the timestamp, so
# they are available over the forecast horizon as well.
TIME_COVARIATE_COLS = ["hour_sin", "hour_cos", "doy_sin", "doy_cos"]

# Stress risk derived from IMS VPD (VPD_ims computed from IMS T + RH;
# available for both past and future).
STRESS_COVARIATE_COL = "stress_risk_ims"

# Mapping from raw sensor CSV column names to the clean names used above.
_SENSOR_COL_MAP = {
    "Air1_PAR_ref": "PAR_site",
    "Air1_VPD_ref": "VPD_site",
    "Air1_leafTemperature_ref": "T_leaf_site",
    "Air1_CO2_ref": "CO2_site",
}

# Regular time grid: 15-minute steps → 96 steps per 24 h day.
FREQ = "15min"
STEPS_PER_DAY = 96
# VPD (kPa) from IMS temperature and relative humidity, used for stress_risk_ims.
def _vpd_from_ims_kpa(T_c: np.ndarray, rh_percent: np.ndarray) -> np.ndarray:
    """Vapour-pressure deficit (kPa) from air temperature and relative humidity.

    Saturation vapour pressure uses the Tetens/Magnus approximation
    (0.611 * exp(17.27*T / (T + 237.3)), in kPa), then
    VPD = esat * (1 - RH/100). RH is clipped to [0, 100] so out-of-range
    sensor values cannot produce a negative deficit.
    """
    rh_frac = np.clip(rh_percent, 0, 100) / 100.0
    esat_kpa = 0.611 * np.exp((17.27 * T_c) / (T_c + 237.3))
    return esat_kpa * (1.0 - rh_frac)
# Covariate mode presets, keyed by name → {"past": [...], "future": [...]}.
# "past" columns are supplied as history only; "future" columns are also
# supplied over the forecast horizon (known day-ahead, e.g. IMS forecasts).
# NOTE(review): "all" uses the 4-column IMS set, not the extended
# tdmax/tdmin set mentioned in the module docstring history — ablation
# showed the extended set did not help. Time/stress columns remain in the
# data for optional use but are not wired into any preset.
COVARIATE_MODES = {
    "none": {"past": [], "future": []},
    "ims": {"past": IMS_COVARIATE_COLS, "future": IMS_COVARIATE_COLS},
    "sensor": {"past": SENSOR_COVARIATE_COLS, "future": []},
    "all": {
        "past": IMS_COVARIATE_COLS + SENSOR_COVARIATE_COLS,
        "future": IMS_COVARIATE_COLS,
    },
}
class ChronosForecaster:
    """Day-ahead A forecaster using Chronos-2 with configurable covariates."""

    def __init__(
        self,
        model_name: str = "amazon/chronos-2",
        device: str = "mps",
        context_days: int = 14,
    ):
        # Context length is stored internally in 15-minute grid steps.
        self.model_name = model_name
        self.device = device
        self.context_steps = context_days * STEPS_PER_DAY
        # Loaded lazily by the `pipeline` property on first access.
        self._pipeline = None

    @property
    def pipeline(self):
        """Chronos-2 pipeline, lazily instantiated on first access."""
        if self._pipeline is not None:
            return self._pipeline
        from chronos import Chronos2Pipeline

        self._pipeline = Chronos2Pipeline.from_pretrained(
            self.model_name,
            device_map=self.device,
            dtype=torch.float32,
        )
        return self._pipeline

    @pipeline.setter
    def pipeline(self, value):
        """Replace the underlying pipeline (e.g. with a fine-tuned model)."""
        self._pipeline = value
# ------------------------------------------------------------------
# Data loading and resampling
# ------------------------------------------------------------------
@staticmethod
def load_data(
    labels_path: Optional[Path] = None,
    ims_path: Optional[Path] = None,
    sensor_path: Optional[Path] = None,
    growing_season_only: bool = True,
) -> pd.DataFrame:
    """Load labels + IMS + on-site sensors, merge, resample to regular grid.

    Growing-season-only mode (default) drops Oct-Apr dormancy months,
    concatenating seasons into a continuous series with season boundaries
    marked by a 'season' column.

    Parameters
    ----------
    labels_path : Optional[Path]
        CSV of A labels with a 'time' column; defaults to
        PROCESSED_DIR / "stage1_labels.csv".
    ims_path : Optional[Path]
        IMS weather CSV keyed by 'timestamp_utc'; defaults to
        IMS_CACHE_DIR / "ims_merged_15min.csv".
    sensor_path : Optional[Path]
        Wide on-site sensor CSV with a 'time' column; defaults to
        SEYMOUR_DIR / "sensors_wide.csv".
    growing_season_only : bool
        If True, keep only months listed in GROWING_SEASON_MONTHS.

    Returns
    -------
    pd.DataFrame on a regular 15-min UTC grid with target 'A' (filled with
    0.0 where no label exists, i.e. overnight/gap rows), interpolated IMS
    and sensor covariates, engineered cyclical time features,
    'stress_risk_ims', and a 'season' (calendar year) column.
    """
    # Local import — NOTE(review): presumably to avoid a circular import
    # at module load; DATA_DIR is imported but unused here.
    from config.settings import DATA_DIR, SEYMOUR_DIR
    labels_path = labels_path or PROCESSED_DIR / "stage1_labels.csv"
    ims_path = ims_path or IMS_CACHE_DIR / "ims_merged_15min.csv"
    sensor_path = sensor_path or SEYMOUR_DIR / "sensors_wide.csv"
    # --- Labels ---
    labels = pd.read_csv(labels_path, parse_dates=["time"])
    labels.rename(columns={"time": "timestamp_utc"}, inplace=True)
    # Force tz-aware UTC so merges against IMS/sensor timestamps align.
    labels["timestamp_utc"] = pd.to_datetime(labels["timestamp_utc"], utc=True)
    # --- IMS ---
    ims = pd.read_csv(ims_path, parse_dates=["timestamp_utc"])
    ims["timestamp_utc"] = pd.to_datetime(ims["timestamp_utc"], utc=True)
    # --- On-site sensors ---
    # Read only the raw sensor columns we map to clean names.
    raw_cols = ["time"] + list(_SENSOR_COL_MAP.keys())
    sensors = pd.read_csv(sensor_path, usecols=raw_cols, parse_dates=["time"])
    sensors.rename(columns={"time": "timestamp_utc", **_SENSOR_COL_MAP}, inplace=True)
    sensors["timestamp_utc"] = pd.to_datetime(sensors["timestamp_utc"], utc=True)
    # --- Merge ---
    # Inner join with IMS drops label rows without weather; sensors are a
    # left join so missing on-site data becomes NaN (interpolated below).
    merged = labels.merge(ims, on="timestamp_utc", how="inner")
    merged = merged.merge(sensors, on="timestamp_utc", how="left")
    merged.sort_values("timestamp_utc", inplace=True)
    merged.set_index("timestamp_utc", inplace=True)
    # --- Resample to regular 15-min grid ---
    # Reindexing to the full range exposes hidden gaps (the v1→v2 fix).
    full_idx = pd.date_range(
        merged.index.min(), merged.index.max(), freq=FREQ, tz="UTC",
    )
    resampled = merged.reindex(full_idx)
    resampled.index.name = "timestamp_utc"
    # Fill A=0 overnight, interpolate covariates
    resampled["A"] = resampled["A"].fillna(0.0)
    all_cov_cols = [
        c for c in IMS_COVARIATE_COLS + SENSOR_COVARIATE_COLS
        if c in resampled.columns
    ]
    for col in all_cov_cols:
        # Time-weighted interpolation over gaps, then edge-fill both ends.
        resampled[col] = (
            resampled[col].interpolate(method="time").ffill().bfill()
        )
        if col in ("ghi_w_m2", "PAR_site"):
            # Radiation cannot be negative; clamp interpolation artifacts.
            resampled[col] = resampled[col].clip(lower=0)
    # Engineered time covariates (deterministic; available for future)
    resampled = add_cyclical_time_features(resampled, index_is_timestamp=True)
    # Stress risk from IMS VPD (past + future; 0–1 scale, clip VPD at 6 kPa)
    if "air_temperature_c" in resampled.columns and "rh_percent" in resampled.columns:
        vpd_ims = _vpd_from_ims_kpa(
            resampled["air_temperature_c"].values,
            resampled["rh_percent"].values,
        )
        resampled[STRESS_COVARIATE_COL] = np.clip(vpd_ims / 6.0, 0.0, 1.0)
    resampled.reset_index(inplace=True)
    # --- Growing-season filter ---
    if growing_season_only:
        resampled["month"] = resampled["timestamp_utc"].dt.month
        resampled = resampled[
            resampled["month"].isin(GROWING_SEASON_MONTHS)
        ].copy()
        resampled.drop(columns=["month"], inplace=True)
        resampled.reset_index(drop=True, inplace=True)
    # Add season column (year of growing season)
    resampled["season"] = resampled["timestamp_utc"].dt.year
    return resampled
@staticmethod
def load_sparse_data(
labels_path: Optional[Path] = None,
ims_path: Optional[Path] = None,
) -> pd.DataFrame:
"""Load original daytime-only merged data (no resampling).
Used to identify daytime timestamps for evaluation masking.
"""
labels_path = labels_path or PROCESSED_DIR / "stage1_labels.csv"
ims_path = ims_path or IMS_CACHE_DIR / "ims_merged_15min.csv"
labels = pd.read_csv(labels_path, parse_dates=["time"])
labels.rename(columns={"time": "timestamp_utc"}, inplace=True)
labels["timestamp_utc"] = pd.to_datetime(labels["timestamp_utc"], utc=True)
ims = pd.read_csv(ims_path, parse_dates=["timestamp_utc"])
ims["timestamp_utc"] = pd.to_datetime(ims["timestamp_utc"], utc=True)
merged = labels.merge(ims, on="timestamp_utc", how="inner")
merged.sort_values("timestamp_utc", inplace=True)
merged.reset_index(drop=True, inplace=True)
return merged
# ------------------------------------------------------------------
# predict_df based forecasting
# ------------------------------------------------------------------
def forecast_day(
    self,
    df: pd.DataFrame,
    context_end_idx: int,
    prediction_length: int = STEPS_PER_DAY,
    covariate_mode: str = "all",
) -> pd.DataFrame:
    """Forecast the next `prediction_length` steps via the predict_df API.

    Parameters
    ----------
    df : regular-grid frame from load_data ('timestamp_utc' and 'A' required).
    context_end_idx : row index where the forecast horizon begins;
        history ends just before it.
    prediction_length : number of 15-min steps to predict (default one day).
    covariate_mode : 'none', 'ims', 'sensor', or 'all' (see COVARIATE_MODES).

    Returns
    -------
    DataFrame with columns timestamp_utc, median, low_10, high_90.
    """
    cfg = COVARIATE_MODES[covariate_mode]
    past_cols = [c for c in cfg["past"] if c in df.columns]
    future_cols = [c for c in cfg["future"] if c in df.columns]

    # History window: at most context_steps rows ending at context_end_idx.
    start = max(0, context_end_idx - self.context_steps)
    ctx = df.iloc[start:context_end_idx].copy()

    # History frame in predict_df layout: timestamp / target / item_id (+ past covs).
    hist = ctx[["timestamp_utc", "A"] + past_cols].rename(
        columns={"timestamp_utc": "timestamp", "A": "target"}
    )
    hist["item_id"] = "A"

    # Known-future covariates: only attached when a full horizon of rows exists.
    future_df = None
    if future_cols:
        fwd = df.iloc[context_end_idx : context_end_idx + prediction_length]
        if len(fwd) >= prediction_length:
            future_df = fwd[["timestamp_utc"] + future_cols].rename(
                columns={"timestamp_utc": "timestamp"}
            )
            future_df["item_id"] = "A"

    result = self.pipeline.predict_df(
        df=hist,
        future_df=future_df,
        id_column="item_id",
        timestamp_column="timestamp",
        target="target",
        prediction_length=prediction_length,
        quantile_levels=[0.1, 0.5, 0.9],
    )

    # Align output timestamps with however many steps the model returned.
    horizon_ts = df["timestamp_utc"].iloc[
        context_end_idx : context_end_idx + prediction_length
    ].values
    return pd.DataFrame({
        "timestamp_utc": horizon_ts[:len(result)],
        "median": result["0.5"].values,
        "low_10": result["0.1"].values,
        "high_90": result["0.9"].values,
    })
# ------------------------------------------------------------------
# LoRA fine-tuning
# ------------------------------------------------------------------
def finetune(
    self,
    df: pd.DataFrame,
    train_ratio: float = 0.75,
    prediction_length: int = STEPS_PER_DAY,
    covariate_mode: str = "all",
    num_steps: int = 500,
    learning_rate: float = 1e-5,
    batch_size: Optional[int] = None,
    output_dir: Optional[str] = None,
) -> None:
    """LoRA fine-tune Chronos-2 on the training portion of the data.

    Uses the dict API for fit() with past and future covariates.
    Only the training portion (before train_ratio split) is used —
    no data leakage into the walk-forward test region.

    Parameters
    ----------
    df : regular-grid frame from load_data.
    train_ratio : fraction of rows treated as training data.
    prediction_length : forecast horizon per training window (steps).
    covariate_mode : key into COVARIATE_MODES.
    num_steps, learning_rate, batch_size : LoRA hyperparameters;
        batch_size defaults to min(32, number of training windows).
    output_dir : adapter save location
        (default OUTPUTS_DIR / "chronos_finetuned").
    """
    split_idx = int(len(df) * train_ratio)
    train_df = df.iloc[:split_idx].copy()
    mode_cfg = COVARIATE_MODES[covariate_mode]
    past_cols = [c for c in mode_cfg["past"] if c in df.columns]
    future_cols = [c for c in mode_cfg["future"] if c in df.columns]
    # Build training inputs: sliding windows over the training data.
    # Each window = context_steps of history + prediction_length of target,
    # sampled every prediction_length steps (stride) for diversity.
    min_window = self.context_steps + prediction_length
    inputs = []
    stride = prediction_length
    for end_idx in range(min_window, len(train_df), stride):
        ctx_start = end_idx - min_window
        ctx_end = end_idx - prediction_length
        target = train_df["A"].iloc[ctx_start:ctx_end].values.astype(np.float32)
        entry: dict = {"target": target}
        if past_cols:
            entry["past_covariates"] = {
                col: train_df[col].iloc[ctx_start:ctx_end].values.astype(np.float32)
                for col in past_cols
            }
        if future_cols:
            # Actual observed values stand in for day-ahead forecasts here.
            entry["future_covariates"] = {
                col: train_df[col].iloc[ctx_end:end_idx].values.astype(np.float32)
                for col in future_cols
            }
        inputs.append(entry)
    if not inputs:
        print("Not enough training data for fine-tuning.")
        return
    # Hold out the last ~10% of windows for validation, but never let the
    # training set go empty: with a single window, int(len*0.9) == 0 used
    # to yield zero train windows and batch_size=0 being passed to fit().
    val_split = max(1, int(len(inputs) * 0.9))
    train_inputs = inputs[:val_split]
    val_inputs = inputs[val_split:] if val_split < len(inputs) else None
    output_dir = output_dir or str(OUTPUTS_DIR / "chronos_finetuned")
    effective_batch = batch_size if batch_size is not None else min(32, len(train_inputs))
    print(f"Fine-tuning with LoRA: {len(train_inputs)} train windows, "
          f"{len(val_inputs) if val_inputs else 0} val windows, "
          f"{num_steps} steps, batch_size={effective_batch}")
    finetuned = self.pipeline.fit(
        inputs=train_inputs,
        prediction_length=prediction_length,
        validation_inputs=val_inputs,
        finetune_mode="lora",
        learning_rate=learning_rate,
        num_steps=num_steps,
        batch_size=effective_batch,
        output_dir=output_dir,
    )
    # Swap in the fine-tuned pipeline via the property setter.
    self.pipeline = finetuned
    print(f"Fine-tuning complete. Model saved → {output_dir}")
# ------------------------------------------------------------------
# Walk-forward benchmark
# ------------------------------------------------------------------
def benchmark(
    self,
    df: Optional[pd.DataFrame] = None,
    train_ratio: float = 0.75,
    prediction_length: int = STEPS_PER_DAY,
    max_test_days: Optional[int] = None,
    covariate_modes: Optional[list[str]] = None,
) -> pd.DataFrame:
    """Walk-forward evaluation across covariate modes.

    Predicts 96 steps (24h) on the regular grid, evaluates ONLY on
    daytime steps where actual A > 0.

    Parameters
    ----------
    df : regular-grid frame from load_data (loaded if None).
    train_ratio : fraction of rows reserved as training context; test
        windows start after this split.
    prediction_length : steps per forecast window (default one day).
    max_test_days : cap on the number of test windows (None = all).
    covariate_modes : modes to compare (default all four presets).

    Returns
    -------
    DataFrame with one row per mode (MAE/RMSE/R2/n_windows/n_steps) plus
    a fixed "ML baseline (best)" row; also written to
    OUTPUTS_DIR / "chronos_benchmark.csv".
    """
    if df is None:
        df = self.load_data()
    if covariate_modes is None:
        covariate_modes = ["none", "ims", "sensor", "all"]
    # Daytime mask: timestamps present in the sparse (label-only) data
    # are, by construction, the measured daytime rows.
    sparse = self.load_sparse_data()
    daytime_timestamps = set(sparse["timestamp_utc"])
    split_idx = int(len(df) * train_ratio)
    # One non-overlapping forecast window per day across the test region.
    test_starts = list(range(split_idx, len(df) - prediction_length, prediction_length))
    if max_test_days is not None:
        test_starts = test_starts[:max_test_days]
    results = {}
    for mode in covariate_modes:
        all_actual, all_pred = [], []
        for start_idx in test_starts:
            forecast_df = self.forecast_day(
                df, start_idx, prediction_length, covariate_mode=mode,
            )
            actual_slice = df.iloc[start_idx : start_idx + prediction_length]
            if len(actual_slice) < prediction_length:
                continue
            daytime_mask = actual_slice["timestamp_utc"].isin(daytime_timestamps).values
            # forecast_day may return fewer steps than requested; align lengths.
            daytime_mask = daytime_mask[:len(forecast_df)]
            if daytime_mask.sum() < 5:
                # Too few daytime points in this window to score meaningfully.
                continue
            actual_day = actual_slice["A"].values[:len(forecast_df)][daytime_mask]
            # Negative photosynthesis forecasts are clipped to zero.
            pred_day = np.clip(forecast_df["median"].values[daytime_mask], 0, None)
            all_actual.append(actual_day)
            all_pred.append(pred_day)
        if not all_actual:
            continue
        # Pool all daytime steps across windows before scoring.
        actual_flat = np.concatenate(all_actual)
        pred_flat = np.concatenate(all_pred)
        results[mode] = {
            "MAE": round(float(mean_absolute_error(actual_flat, pred_flat)), 4),
            "RMSE": round(
                float(np.sqrt(mean_squared_error(actual_flat, pred_flat))), 4
            ),
            "R2": round(float(r2_score(actual_flat, pred_flat)), 4),
            "n_windows": len(all_actual),
            "n_steps": len(actual_flat),
        }
        print(f" {mode:12s}: MAE={results[mode]['MAE']:.4f} "
              f"RMSE={results[mode]['RMSE']:.4f} R²={results[mode]['R2']:.4f} "
              f"({results[mode]['n_windows']} windows, "
              f"{results[mode]['n_steps']} daytime steps)")
    comparison = pd.DataFrame(results).T
    comparison.index.name = "mode"
    comparison.reset_index(inplace=True)
    # Append ML baseline row for app comparison
    # (hard-coded MAE 2.7 — see module docstring history).
    ml_baseline = pd.DataFrame([{
        "mode": "ML baseline (best)",
        "MAE": 2.7,
        "RMSE": np.nan,
        "R2": np.nan,
        "n_windows": np.nan,
        "n_steps": np.nan,
    }])
    comparison = pd.concat([comparison, ml_baseline], ignore_index=True)
    OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
    comparison.to_csv(OUTPUTS_DIR / "chronos_benchmark.csv", index=False)
    print(f"Saved benchmark → {OUTPUTS_DIR / 'chronos_benchmark.csv'}")
    return comparison
# ------------------------------------------------------------------
# Sample forecast plot
# ------------------------------------------------------------------
def plot_sample_forecast(
    self,
    df: Optional[pd.DataFrame] = None,
    test_day_idx: int = 0,
    train_ratio: float = 0.75,
    prediction_length: int = STEPS_PER_DAY,
) -> None:
    """Generate a sample forecast plot with confidence bands.

    Plots actual A vs. the Chronos-2 median and 10–90% quantile band for
    one test-region day, and saves the figure to
    OUTPUTS_DIR / "chronos_forecast_sample.png".
    """
    import matplotlib.pyplot as plt

    if df is None:
        df = self.load_data()

    split_idx = int(len(df) * train_ratio)
    start_idx = split_idx + test_day_idx * prediction_length
    if start_idx + prediction_length > len(df):
        print("Not enough data for sample forecast plot.")
        return

    forecast_df = self.forecast_day(
        df, start_idx, prediction_length, covariate_mode="all",
    )
    actual = df["A"].iloc[start_idx : start_idx + prediction_length].values

    # X axis in hours ahead (15-min steps → 0.25 h per step); negative
    # forecast values are clipped to zero before plotting.
    n_steps = len(forecast_df)
    hours = np.arange(n_steps) * 0.25
    median = np.clip(forecast_df["median"].values, 0, None)
    band_low = np.clip(forecast_df["low_10"].values, 0, None)
    band_high = forecast_df["high_90"].values

    fig, ax = plt.subplots(figsize=(12, 5))
    ax.plot(hours, actual[:n_steps], "k-", linewidth=1.5, label="Actual A")
    ax.plot(hours, median, "b-", linewidth=1.5, label="Chronos-2 median")
    ax.fill_between(
        hours, band_low, band_high,
        alpha=0.25, color="steelblue", label="10-90% CI",
    )
    ax.set_xlabel("Hours ahead")
    ax.set_ylabel("A (umol CO2 m-2 s-1)")
    ax.axhline(0, color="gray", linewidth=0.5, linestyle="--")
    ts = df["timestamp_utc"].iloc[start_idx]
    ax.set_title(f"Chronos-2 Day-Ahead Forecast — {ts:%Y-%m-%d %H:%M}")
    ax.legend()
    ax.grid(True, alpha=0.3)

    OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
    fig.savefig(
        OUTPUTS_DIR / "chronos_forecast_sample.png", dpi=150, bbox_inches="tight",
    )
    plt.close(fig)
    print(f"Saved plot → {OUTPUTS_DIR / 'chronos_forecast_sample.png'}")
# ----------------------------------------------------------------------
# CLI entry point
# ----------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    # Command-line options mirror the forecaster's main knobs.
    cli = argparse.ArgumentParser(description="Chronos-2 day-ahead A forecasting")
    cli.add_argument("--device", default="mps", help="torch device")
    cli.add_argument("--context-days", type=int, default=14, help="context window in days")
    cli.add_argument("--max-days", type=int, default=None, help="limit test windows")
    cli.add_argument("--plot", action="store_true", help="generate sample forecast plot")
    cli.add_argument(
        "--finetune", action="store_true",
        help="LoRA fine-tune before benchmarking",
    )
    cli.add_argument("--ft-steps", type=int, default=500, help="fine-tuning steps")
    cli.add_argument(
        "--modes", nargs="+", default=["none", "ims", "sensor", "all"],
        help="covariate modes to benchmark",
    )
    args = cli.parse_args()

    forecaster = ChronosForecaster(
        device=args.device, context_days=args.context_days,
    )

    print("Loading data (growing-season grid + on-site sensors)...")
    df = forecaster.load_data()
    print(f" Grid: {len(df)} rows, seasons: {sorted(df['season'].unique())}")

    # Optional LoRA fine-tuning on the training split before evaluation.
    if args.finetune:
        print(f"\nLoRA fine-tuning ({args.ft_steps} steps)...")
        forecaster.finetune(df, num_steps=args.ft_steps, covariate_mode="all")

    print("\nRunning walk-forward benchmark (daytime-only evaluation)...")
    results = forecaster.benchmark(
        df, max_test_days=args.max_days, covariate_modes=args.modes,
    )
    print(f"\n{results.to_string(index=False)}")

    if args.plot:
        print("\nGenerating sample forecast plot...")
        forecaster.plot_sample_forecast(df)