File size: 5,393 Bytes
938949f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
PhotosynthesisPredictor: train and evaluate regression models on IMS
features; report RMSE, MAE, R2.
"""

from __future__ import annotations

from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

try:
    from xgboost import XGBRegressor
    _HAS_XGB = True
except ImportError:
    _HAS_XGB = False

try:
    import matplotlib.pyplot as plt
    _HAS_PLOT = True
except ImportError:
    _HAS_PLOT = False


class PhotosynthesisPredictor:
    """Train several regressors on one training set and compare them on a test set.

    The registry ``self.models`` maps a display name to an estimator exposing
    ``fit``/``predict``. ``self.results`` is filled by :meth:`evaluate` and maps
    each model name to its predictions and its RMSE / MAE / R2 scores.
    """

    def __init__(self):
        """Create the (unfitted) model registry and an empty results store."""
        # Insertion order is the order in which models are trained and evaluated.
        self.models: dict = {
            "LinearRegression": LinearRegression(),
            # Seeded like the other estimators: scikit-learn permutes features
            # at every split, so ties between equally good splits are otherwise
            # broken differently from run to run.
            "DecisionTree": DecisionTreeRegressor(
                max_depth=6, min_samples_leaf=10, random_state=42,
            ),
            "RandomForest": RandomForestRegressor(
                n_estimators=200, max_depth=8, min_samples_leaf=5,
                n_jobs=-1, random_state=42,
            ),
            "GradientBoosting": GradientBoostingRegressor(
                n_estimators=300, max_depth=4, learning_rate=0.05,
                min_samples_leaf=10, random_state=42,
            ),
        }
        # XGBoost is optional; it is only registered when the import succeeded.
        if _HAS_XGB:
            self.models["XGBoost"] = XGBRegressor(
                n_estimators=300, max_depth=4, learning_rate=0.05,
                min_child_weight=10, reg_alpha=0.1, reg_lambda=1.0,
                n_jobs=-1, random_state=42,
            )
        # model name -> {"predictions", "rmse", "mae", "r2"}; written by evaluate().
        self.results: dict[str, dict] = {}

    def train(self, X_train: pd.DataFrame, y_train: pd.Series) -> None:
        """Fit every registered model on ``(X_train, y_train)``.

        Args:
            X_train: Training feature matrix, passed unchanged to each ``fit``.
            y_train: Training target, passed unchanged to each ``fit``.
        """
        for model in self.models.values():
            model.fit(X_train, y_train)

    def evaluate(
        self,
        X_test: pd.DataFrame,
        y_test: pd.Series,
    ) -> pd.DataFrame:
        """Predict with each model and score it against ``y_test``.

        For every registered model the predictions and the RMSE, MAE and R2
        scores are stored in ``self.results[name]`` (overwriting any previous
        entry for that name).

        Args:
            X_test: Test feature matrix, passed unchanged to each ``predict``.
            y_test: True target values for the test rows.

        Returns:
            A comparison table with one row per model and the columns
            ``model``, ``RMSE``, ``MAE`` and ``R2``, in registry order.
        """
        rows = []
        for name, model in self.models.items():
            pred = model.predict(X_test)
            rmse = float(np.sqrt(mean_squared_error(y_test, pred)))
            mae = float(mean_absolute_error(y_test, pred))
            r2 = float(r2_score(y_test, pred))
            self.results[name] = {"predictions": pred, "rmse": rmse, "mae": mae, "r2": r2}
            rows.append({"model": name, "RMSE": rmse, "MAE": mae, "R2": r2})
        # Pin the columns so an empty registry still yields the documented schema.
        return pd.DataFrame(rows, columns=["model", "RMSE", "MAE", "R2"])

    def get_feature_importance(self, model_name: Optional[str] = None) -> pd.DataFrame:
        """Return the feature importances of one model, highest first.

        Args:
            model_name: Name of the model to inspect. When omitted (or empty),
                the first usable model in the preference order
                XGBoost > GradientBoosting > RandomForest > DecisionTree is used.

        Returns:
            A DataFrame with the columns ``feature`` and ``importance`` sorted
            by descending importance. Feature labels come from the model's
            ``feature_names_in_`` when it has one, otherwise they are the
            positional indices ``0..n-1``. The frame is empty (but keeps both
            columns) when no candidate exposes ``feature_importances_``.
        """
        if model_name:
            candidates = [model_name]
        else:
            candidates = ["XGBoost", "GradientBoosting", "RandomForest", "DecisionTree"]

        for name in candidates:
            # A missing model and a model without importances (e.g. one that
            # raises AttributeError on access) both resolve to None here.
            importances = getattr(self.models.get(name), "feature_importances_", None)
            if importances is None:
                continue
            model = self.models[name]
            labels = getattr(model, "feature_names_in_", list(range(len(importances))))
            table = pd.DataFrame({"feature": labels, "importance": importances})
            return table.sort_values("importance", ascending=False)

        # Keep the schema on the empty result so column access does not raise.
        return pd.DataFrame(columns=["feature", "importance"])

    def plot_results(
        self,
        y_test: pd.Series,
        predictions: Optional[dict[str, np.ndarray]] = None,
        save_path: Optional[Path] = None,
    ) -> None:
        """Draw a predicted-vs-actual scatter and an overlay of the top models.

        The left panel scatters the best-scoring model's predictions against
        ``y_test`` with a 1:1 reference line. The right panel overlays
        ``y_test`` with the predictions of the two best-scoring models, plotted
        against the sample position. Models are ranked by the R2 stored in
        ``self.results``; models without a stored score rank last, in the order
        they appear in the predictions.

        Does nothing when matplotlib could not be imported or when there are
        no predictions to draw.

        Args:
            y_test: True target values, aligned with the prediction arrays.
            predictions: Mapping of model name to prediction array. When
                omitted or empty, the predictions stored in ``self.results``
                are used.
            save_path: Where to write the figure (150 dpi). Parent directories
                are created as needed. Nothing is written when omitted.
        """
        if not _HAS_PLOT:
            return
        preds = predictions or {n: r["predictions"] for n, r in self.results.items()}
        if not preds:
            return

        def _r2(model: str) -> float:
            # Unscored models sort after every scored one.
            return self.results.get(model, {}).get("r2", float("-inf"))

        # Rank only the models we can actually draw; sorted() is stable, so
        # equally scored (or unscored) models keep their order from preds.
        ranked = sorted(preds, key=_r2, reverse=True)
        name = ranked[0]
        actual = np.asarray(y_test)
        best_pred = np.asarray(preds[name])

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        try:
            # Scatter of the best model against the 1:1 line.
            ax = axes[0]
            ax.scatter(actual, best_pred, alpha=0.5, s=10)
            mn = min(actual.min(), best_pred.min())
            mx = max(actual.max(), best_pred.max())
            ax.plot([mn, mx], [mn, mx], "k--", label="1:1")
            ax.set_xlabel("Approx A (µmol m⁻² s⁻¹)")
            ax.set_ylabel("Predicted A")
            ax.set_title(f"Predicted vs approx A ({name})")
            ax.legend()
            ax.set_aspect("equal")

            # Overlay of the top two models on the actual series.
            ax = axes[1]
            ax.plot(actual, label="Approx A", alpha=0.8)
            for n in ranked[:2]:
                r2 = self.results.get(n, {}).get("r2")
                label = n if r2 is None else f"{n} (R²={r2:.2f})"
                ax.plot(np.asarray(preds[n]), label=label, alpha=0.7)
            ax.set_xlabel("Time index")
            ax.set_ylabel("A (umol m-2 s-1)")
            ax.set_title("Time series overlay")
            ax.legend()

            fig.tight_layout()
            if save_path:
                target = Path(save_path)  # also accept a plain string path
                target.parent.mkdir(parents=True, exist_ok=True)
                fig.savefig(target, dpi=150)
        finally:
            # Always release this figure, even if plotting or saving failed.
            plt.close(fig)