"""Evaluate a TimeSeriesPredictor on GIFT-Eval datasets and write metrics and plots to disk."""

import argparse
import logging
import warnings
from pathlib import Path

import matplotlib
from gluonts.model.evaluation import evaluate_model
from gluonts.time_feature import get_seasonality
from linear_operator.utils.cholesky import NumericalWarning

from src.gift_eval.constants import (
    DATASET_PROPERTIES,
    MED_LONG_DATASETS,
    METRICS,
    PRETTY_NAMES,
)
from src.gift_eval.core import DatasetMetadata, EvaluationItem, expand_datasets_arg
from src.gift_eval.data import Dataset
from src.gift_eval.predictor import TimeSeriesPredictor
from src.gift_eval.results import write_results_to_disk
from src.plotting.gift_eval_utils import create_plots_for_dataset

logger = logging.getLogger(__name__)

# Silence noisy numerical warnings and chatty third-party loggers.
warnings.filterwarnings("ignore", category=NumericalWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
matplotlib.set_loglevel("WARNING")
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
logging.getLogger("PIL").setLevel(logging.WARNING)


class WarningFilter(logging.Filter):
    """Logging filter that drops records whose message contains a given substring."""

    def __init__(self, text_to_filter: str) -> None:
        super().__init__()
        self.text_to_filter = text_to_filter

    def filter(self, record: logging.LogRecord) -> bool:
        return self.text_to_filter not in record.getMessage()


# gluonts logs a warning for every forecast whose mean is not stored; filter it out.
gts_logger = logging.getLogger("gluonts.model.forecast")
gts_logger.addFilter(WarningFilter("The mean prediction is not stored in the forecast data"))


def construct_evaluation_data(
    dataset_name: str,
    dataset_storage_path: str,
    terms: list[str] | None = None,
    max_windows: int | None = None,
) -> list[tuple[Dataset, DatasetMetadata]]:
    """Build datasets and rich metadata per term for a dataset name."""
    if terms is None:
        terms = ["short", "medium", "long"]

    sub_datasets: list[tuple[Dataset, DatasetMetadata]] = []

    if "/" in dataset_name:
        ds_key, ds_freq = dataset_name.split("/")
        ds_key = ds_key.lower()
        ds_key = PRETTY_NAMES.get(ds_key, ds_key)
    else:
        ds_key = dataset_name.lower()
        ds_key = PRETTY_NAMES.get(ds_key, ds_key)
        ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")

    for term in terms:
        # Medium/long horizons are only defined for a subset of datasets.
        if term in ("medium", "long") and dataset_name not in MED_LONG_DATASETS:
            continue

        # Probe the dataset once to determine its target dimensionality.
        probe_dataset = Dataset(
            name=dataset_name,
            term=term,
            to_univariate=False,
            storage_path=dataset_storage_path,
            max_windows=max_windows,
        )

        to_univariate = probe_dataset.target_dim > 1

        dataset = Dataset(
            name=dataset_name,
            term=term,
            to_univariate=to_univariate,
            storage_path=dataset_storage_path,
            max_windows=max_windows,
        )

        # Seasonality comes from the dataset frequency; for metadata, prefer the
        # explicit frequency from the dataset name or properties table when available.
        season_length = get_seasonality(dataset.freq)
        actual_freq = ds_freq if ds_freq else dataset.freq

        metadata = DatasetMetadata(
            full_name=f"{ds_key}/{actual_freq}/{term}",
            key=ds_key,
            freq=actual_freq,
            term=term,
            season_length=season_length,
            target_dim=probe_dataset.target_dim,
            to_univariate=to_univariate,
            prediction_length=dataset.prediction_length,
            windows=dataset.windows,
        )

        sub_datasets.append((dataset, metadata))

    return sub_datasets


def evaluate_datasets(
    predictor: TimeSeriesPredictor,
    dataset: str,
    dataset_storage_path: str,
    terms: list[str] | None = None,
    max_windows: int | None = None,
    batch_size: int = 48,
    max_context_length: int | None = 1024,
    create_plots: bool = False,
    max_plots_per_dataset: int = 10,
) -> list[EvaluationItem]:
    """Evaluate the predictor on one dataset across the requested terms."""
    if terms is None:
        terms = ["short", "medium", "long"]

    sub_datasets = construct_evaluation_data(
        dataset_name=dataset,
        dataset_storage_path=dataset_storage_path,
        terms=terms,
        max_windows=max_windows,
    )

    results: list[EvaluationItem] = []
    for i, (sub_dataset, metadata) in enumerate(sub_datasets):
        logger.info(f"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}")
        logger.info(f" Dataset size: {len(sub_dataset.test_data)}")
        logger.info(f" Frequency: {sub_dataset.freq}")
        logger.info(f" Term: {metadata.term}")
        logger.info(f" Prediction length: {sub_dataset.prediction_length}")
        logger.info(f" Target dimensions: {sub_dataset.target_dim}")
        logger.info(f" Windows: {sub_dataset.windows}")

        # Point the predictor at this dataset's horizon, frequency, and batching settings.
        predictor.set_dataset_context(
            prediction_length=sub_dataset.prediction_length,
            freq=sub_dataset.freq,
            batch_size=batch_size,
            max_context_length=max_context_length,
        )

        res = evaluate_model(
            model=predictor,
            test_data=sub_dataset.test_data,
            metrics=METRICS,
            axis=None,
            mask_invalid_label=True,
            allow_nan_forecast=False,
            seasonality=metadata.season_length,
        )

        figs: list[tuple[object, str]] = []
        if create_plots:
            forecasts = predictor.predict(sub_dataset.test_data.input)
            figs = create_plots_for_dataset(
                forecasts=forecasts,
                test_data=sub_dataset.test_data,
                dataset_metadata=metadata,
                max_plots=max_plots_per_dataset,
                max_context_length=max_context_length,
            )

        results.append(EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs))

    return results


def _run_evaluation(
    predictor: TimeSeriesPredictor,
    datasets: list[str] | str,
    terms: list[str],
    dataset_storage_path: str,
    max_windows: int | None = None,
    batch_size: int = 48,
    max_context_length: int | None = 1024,
    output_dir: str = "gift_eval_results",
    model_name: str = "TimeSeriesModel",
    create_plots: bool = False,
    max_plots: int = 10,
) -> None:
    """Shared evaluation workflow used by both entry points."""
    datasets_to_run = expand_datasets_arg(datasets)
    results_root = Path(output_dir)

    for ds_name in datasets_to_run:
        items = evaluate_datasets(
            predictor=predictor,
            dataset=ds_name,
            dataset_storage_path=dataset_storage_path,
            terms=terms,
            max_windows=max_windows,
            batch_size=batch_size,
            max_context_length=max_context_length,
            create_plots=create_plots,
            max_plots_per_dataset=max_plots,
        )
        write_results_to_disk(
            items=items,
            dataset_name=ds_name,
            output_dir=results_root,
            model_name=model_name,
            create_plots=create_plots,
        )


def evaluate_from_paths(
    model_path: str,
    config_path: str,
    datasets: list[str] | str,
    terms: list[str],
    dataset_storage_path: str,
    max_windows: int | None = None,
    batch_size: int = 48,
    max_context_length: int | None = 1024,
    output_dir: str = "gift_eval_results",
    model_name: str = "TimeSeriesModel",
    create_plots: bool = False,
    max_plots: int = 10,
) -> None:
    """Entry point: load a model from disk and save metrics/plots to disk."""
    # Validate input paths before loading the model.
    if not Path(model_path).exists():
        raise FileNotFoundError(f"Model path does not exist: {model_path}")
    if not Path(config_path).exists():
        raise FileNotFoundError(f"Config path does not exist: {config_path}")

    # The prediction length and frequency here are placeholders; they are
    # overridden per dataset via predictor.set_dataset_context during evaluation.
    predictor = TimeSeriesPredictor.from_paths(
        model_path=model_path,
        config_path=config_path,
        ds_prediction_length=1,
        ds_freq="D",
        batch_size=batch_size,
        max_context_length=max_context_length,
    )

    _run_evaluation(
        predictor=predictor,
        datasets=datasets,
        terms=terms,
        dataset_storage_path=dataset_storage_path,
        max_windows=max_windows,
        batch_size=batch_size,
        max_context_length=max_context_length,
        output_dir=output_dir,
        model_name=model_name,
        create_plots=create_plots,
        max_plots=max_plots,
    )


def evaluate_in_memory(
    model,
    config: dict,
    datasets: list[str] | str,
    terms: list[str],
    dataset_storage_path: str,
    max_windows: int | None = None,
    batch_size: int = 48,
    max_context_length: int | None = 1024,
    output_dir: str = "gift_eval_results",
    model_name: str = "TimeSeriesModel",
    create_plots: bool = False,
    max_plots: int = 10,
) -> None:
    """Entry point: evaluate an in-memory model and save metrics/plots to disk."""
    # As above, prediction length and frequency are placeholders that are
    # overridden per dataset during evaluation.
    predictor = TimeSeriesPredictor.from_model(
        model=model,
        config=config,
        ds_prediction_length=1,
        ds_freq="D",
        batch_size=batch_size,
        max_context_length=max_context_length,
    )

    _run_evaluation(
        predictor=predictor,
        datasets=datasets,
        terms=terms,
        dataset_storage_path=dataset_storage_path,
        max_windows=max_windows,
        batch_size=batch_size,
        max_context_length=max_context_length,
        output_dir=output_dir,
        model_name=model_name,
        create_plots=create_plots,
        max_plots=max_plots,
    )

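
# A minimal sketch of programmatic use, assuming a trained model object and its
# config dict are already in memory (the names and paths below are illustrative
# placeholders, not values shipped with this repo):
#
#   evaluate_in_memory(
#       model=my_model,
#       config=my_config,
#       datasets="all",
#       terms=["short"],
#       dataset_storage_path="/path/to/gift_eval",
#       output_dir="gift_eval_results",
#   )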


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Evaluate TimeSeriesModel on GIFT-Eval datasets")

    # Model arguments
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the trained model checkpoint",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to the model configuration YAML file",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="TimeSeriesModel",
        help="Name identifier for the model",
    )

    # Dataset arguments
    parser.add_argument(
        "--datasets",
        type=str,
        default="all",
        help="Comma-separated list of dataset names to evaluate (or 'all')",
    )
    parser.add_argument(
        "--dataset_storage_path",
        type=str,
        default="/work/dlclarge2/moroshav-GiftEvalPretrain/gift_eval",
        help="Path to the GIFT-Eval dataset storage directory",
    )
    parser.add_argument(
        "--terms",
        type=str,
        default="short,medium,long",
        help="Comma-separated list of prediction terms to evaluate",
    )
    parser.add_argument(
        "--max_windows",
        type=int,
        default=None,
        help="Maximum number of windows to use for evaluation",
    )

    # Inference arguments
    parser.add_argument("--batch_size", type=int, default=48, help="Batch size for model inference")
    parser.add_argument(
        "--max_context_length",
        type=int,
        default=1024,
        help="Maximum context length to use for inference",
    )

    # Output arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default="gift_eval_results",
        help="Directory to save evaluation results",
    )

    # Plotting arguments
    parser.add_argument(
        "--create_plots",
        action="store_true",
        help="Create and save plots for each evaluation window",
    )
    parser.add_argument(
        "--max_plots_per_dataset",
        type=int,
        default=10,
        help="Maximum number of plots to create per dataset term",
    )

    args = parser.parse_args()
    args.terms = args.terms.split(",")
    args.datasets = args.datasets.split(",")
    return args


def _configure_logging() -> None:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )


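# Example invocation (the script path and argument values are illustrative):
#   python path/to/this_script.py \
#       --model_path /path/to/model.ckpt \
#       --config_path /path/to/config.yaml \
#       --datasets all \
#       --terms short,medium,long \
#       --create_plots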
if __name__ == "__main__":
    _configure_logging()
    args = _parse_args()
    logger.info(f"Command Line Arguments: {vars(args)}")
    try:
        evaluate_from_paths(
            model_path=args.model_path,
            config_path=args.config_path,
            datasets=args.datasets,
            terms=args.terms,
            dataset_storage_path=args.dataset_storage_path,
            max_windows=args.max_windows,
            batch_size=args.batch_size,
            max_context_length=args.max_context_length,
            output_dir=args.output_dir,
            model_name=args.model_name,
            create_plots=args.create_plots,
            max_plots=args.max_plots_per_dataset,
        )
    except Exception as e:
        logger.error(f"Evaluation failed: {e}")
        raise