| | """Gradio front-end for Fault_Classification_PMU_Data models. |
| | |
| | The application loads a CNN-LSTM model (and accompanying scaler/metadata) |
| | produced by ``fault_classification_pmu.py`` and exposes a streamlined |
| | prediction interface optimised for Hugging Face Spaces deployment. It supports |
| | raw PMU time-series CSV uploads as well as manual comma separated feature |
| | vectors. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import json |
| | import os |
| | import shutil |
| |
|
| | os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1") |
| | os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2") |
| | os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0") |
| |
|
| | import re |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Sequence, Tuple, Union |
| |
|
| | import gradio as gr |
| | import joblib |
| | import numpy as np |
| | import pandas as pd |
| | import requests |
| | from huggingface_hub import hf_hub_download |
| | from tensorflow.keras.models import load_model |
| |
|
| | from fault_classification_pmu import ( |
| | DEFAULT_FEATURE_COLUMNS as TRAINING_DEFAULT_FEATURE_COLUMNS, |
| | LABEL_GUESS_CANDIDATES as TRAINING_LABEL_GUESSES, |
| | train_from_dataframe, |
| | ) |
| |
|
| | |
| | |
| | |
# Defaults mirrored from the training pipeline so the UI can run stand-alone.
DEFAULT_FEATURE_COLUMNS: List[str] = list(TRAINING_DEFAULT_FEATURE_COLUMNS)
DEFAULT_SEQUENCE_LENGTH = 32  # default samples per model input window
DEFAULT_STRIDE = 4  # default step between consecutive windows

# Local artefact filenames (each overridable through the environment).
LOCAL_MODEL_FILE = os.environ.get("PMU_MODEL_FILE", "pmu_cnn_lstm_model.keras")
LOCAL_SCALER_FILE = os.environ.get("PMU_SCALER_FILE", "pmu_feature_scaler.pkl")
LOCAL_METADATA_FILE = os.environ.get("PMU_METADATA_FILE", "pmu_metadata.json")

# Directory where trained artefacts are written; created eagerly at import.
MODEL_OUTPUT_DIR = Path(os.environ.get("PMU_MODEL_DIR", "model")).resolve()
MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Optional Hugging Face Hub repo used as a fallback source for artefacts.
# Empty HUB_REPO disables all Hub downloads (see download_from_hub).
HUB_REPO = os.environ.get("PMU_HUB_REPO", "")
HUB_MODEL_FILENAME = os.environ.get("PMU_HUB_MODEL_FILENAME", LOCAL_MODEL_FILE)
HUB_SCALER_FILENAME = os.environ.get("PMU_HUB_SCALER_FILENAME", LOCAL_SCALER_FILE)
HUB_METADATA_FILENAME = os.environ.get("PMU_HUB_METADATA_FILENAME", LOCAL_METADATA_FILE)

# Names of environment variables that may point at explicit artefact paths.
ENV_MODEL_PATH = "PMU_MODEL_PATH"
ENV_SCALER_PATH = "PMU_SCALER_PATH"
ENV_METADATA_PATH = "PMU_METADATA_PATH"
|
| | |
| | |
| | |
| |
|
| |
|
def download_from_hub(filename: str) -> Optional[Path]:
    """Fetch *filename* from the configured Hugging Face Hub repository.

    Returns the local cache path on success, or ``None`` when no repo is
    configured, no filename is given, or the download fails (the app then
    continues without a pre-trained model).
    """
    if not HUB_REPO or not filename:
        return None
    try:
        # Bug fix: this log line previously printed the literal "(unknown)"
        # instead of the file actually being downloaded.
        print(f"Downloading {filename} from {HUB_REPO} ...")
        path = hf_hub_download(repo_id=HUB_REPO, filename=filename)
        print("Downloaded", path)
        return Path(path)
    except Exception as exc:
        print("Failed to download", filename, "from", HUB_REPO, ":", exc)
        print("Continuing without pre-trained model...")
        return None
| |
|
| |
|
def resolve_artifact(
    local_name: str, env_var: str, hub_filename: str
) -> Optional[Path]:
    """Locate an artefact, preferring local copies over the Hub.

    Search order: the bare local filename, the same name inside
    ``MODEL_OUTPUT_DIR``, then an explicit path taken from *env_var*.
    When none exists, fall back to downloading *hub_filename* from
    ``HUB_REPO`` (or return ``None`` when no repo is configured).
    """
    print(f"Resolving artifact: {local_name}, env: {env_var}, hub: {hub_filename}")
    search: List[Path] = []
    if local_name:
        search.append(Path(local_name))
        search.append(MODEL_OUTPUT_DIR / Path(local_name).name)
    override = os.environ.get(env_var)
    if override:
        search.append(Path(override))

    for option in search:
        if option and option.exists():
            print(f"Found local artifact: {option}")
            return option

    print("No local artifacts found, checking hub...")
    if not HUB_REPO:
        print("No HUB_REPO configured, skipping download")
        return None
    return download_from_hub(hub_filename)
| |
|
| |
|
def load_metadata(path: Optional[Path]) -> Dict:
    """Parse the metadata JSON at *path*; return ``{}`` on any failure."""
    if path is None or not path.exists():
        return {}
    try:
        return json.loads(path.read_text())
    except Exception as exc:
        print("Failed to read metadata", path, exc)
        return {}
| |
|
| |
|
def try_load_scaler(path: Optional[Path]):
    """Deserialize the feature scaler at *path*; ``None`` if unavailable."""
    if not path:
        return None
    try:
        loaded = joblib.load(path)
    except Exception as exc:
        print("Failed to load scaler", path, exc)
        return None
    print("Loaded scaler from", path)
    return loaded
| |
|
| |
|
| | |
# ---------------------------------------------------------------------------
# Eager start-up: resolve model/scaler/metadata artefacts once at import time.
# Each step is individually guarded so that a single missing or corrupt
# artefact degrades that value to None/{} instead of crashing the app.
# ---------------------------------------------------------------------------
print("Starting application initialization...")
try:
    MODEL_PATH = resolve_artifact(LOCAL_MODEL_FILE, ENV_MODEL_PATH, HUB_MODEL_FILENAME)
    print(f"Model path resolved: {MODEL_PATH}")
except Exception as e:
    print(f"Model path resolution failed: {e}")
    MODEL_PATH = None

try:
    SCALER_PATH = resolve_artifact(
        LOCAL_SCALER_FILE, ENV_SCALER_PATH, HUB_SCALER_FILENAME
    )
    print(f"Scaler path resolved: {SCALER_PATH}")
except Exception as e:
    print(f"Scaler path resolution failed: {e}")
    SCALER_PATH = None

try:
    METADATA_PATH = resolve_artifact(
        LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME
    )
    print(f"Metadata path resolved: {METADATA_PATH}")
except Exception as e:
    print(f"Metadata path resolution failed: {e}")
    METADATA_PATH = None

# load_metadata already swallows read/parse errors; this extra guard keeps
# start-up alive even if the path object itself misbehaves.
try:
    METADATA = load_metadata(METADATA_PATH)
    print(f"Metadata loaded: {len(METADATA)} entries")
except Exception as e:
    print(f"Metadata loading failed: {e}")
    METADATA = {}
| |
|
| | |
# Gradio queue tuning: cap the number of pending events, and limit how many
# handlers may run concurrently (training/prediction jobs are CPU-heavy).
QUEUE_MAX_SIZE = 32

EVENT_CONCURRENCY_LIMIT = 2
| |
|
| |
|
def try_load_model(path: Optional[Path], model_type: str, model_format: str):
    """Load a persisted model, picking the loader from its type/format.

    SVM models (or anything stored in joblib format) are deserialized with
    ``joblib.load``; everything else goes through Keras ``load_model``.
    Returns ``None`` when *path* is missing or loading fails.
    """
    if not path:
        return None
    use_joblib = model_type == "svm" or model_format == "joblib"
    try:
        loaded = joblib.load(path) if use_joblib else load_model(path)
    except Exception as exc:
        print("Failed to load model", path, exc)
        return None
    print("Loaded model from", path)
    return loaded
| |
|
| |
|
# Mutable runtime configuration; initialised from the defaults above and
# refreshed when a model/metadata bundle is (re)loaded.
FEATURE_COLUMNS: List[str] = list(DEFAULT_FEATURE_COLUMNS)  # model input columns, in order
LABEL_CLASSES: List[str] = []  # class names, populated from training metadata
LABEL_COLUMN: str = "Fault"  # CSV column holding the target label
SEQUENCE_LENGTH: int = DEFAULT_SEQUENCE_LENGTH  # samples per inference window
DEFAULT_WINDOW_STRIDE: int = DEFAULT_STRIDE  # step between windows
MODEL_TYPE: str = "cnn_lstm"  # one of the MODEL_FILENAME_BY_TYPE keys
MODEL_FORMAT: str = "keras"  # "keras" or "joblib" — selects the loader
| |
|
| |
|
def _model_output_path(filename: str) -> str:
    """Absolute path for *filename* (basename only) inside ``MODEL_OUTPUT_DIR``."""
    return str(MODEL_OUTPUT_DIR.joinpath(Path(filename).name))
| |
|
| |
|
# Default artefact filename for each supported model architecture.
MODEL_FILENAME_BY_TYPE: Dict[str, str] = {
    "cnn_lstm": Path(LOCAL_MODEL_FILE).name,
    "tcn": "pmu_tcn_model.keras",
    "svm": "pmu_svm_model.joblib",
}

# Columns a raw PMU CSV must provide (same set as the training defaults).
REQUIRED_PMU_COLUMNS: Tuple[str, ...] = tuple(DEFAULT_FEATURE_COLUMNS)
# Where manually uploaded training CSVs are persisted (created eagerly).
TRAINING_UPLOAD_DIR = Path(
    os.environ.get("PMU_TRAINING_UPLOAD_DIR", "training_uploads")
)
TRAINING_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

# GitHub repository mirroring the PMU training dataset, laid out as
# year/month/day/<file>.csv (see the list_remote_* helpers below).
TRAINING_DATA_REPO = os.environ.get(
    "PMU_TRAINING_DATA_REPO", "VincentCroft/ThesisModelData"
)
TRAINING_DATA_BRANCH = os.environ.get("PMU_TRAINING_DATA_BRANCH", "main")
# Local cache directory for CSVs downloaded from the data repository.
TRAINING_DATA_DIR = Path(os.environ.get("PMU_TRAINING_DATA_DIR", "training_dataset"))
TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)

# In-memory memo of GitHub directory listings, keyed by repo path
# ("__root__" for the repository root).
GITHUB_CONTENT_CACHE: Dict[str, List[Dict[str, Any]]] = {}
| |
|
| |
|
# CSS injected into the Gradio Blocks app: lays out the available-files grid,
# pins loading overlays over their sections, and normalises button/dropdown
# widths in the date-browser and artifact-download rows.
# NOTE(review): the `#available-files-section .gradio-loading` rule appears
# twice (once in a combined selector, once standalone) — harmless duplication
# that could be consolidated.
APP_CSS = """
#available-files-section {
    position: relative;
    display: flex;
    flex-direction: column;
    gap: 0.75rem;
    border-radius: 0.75rem;
    isolation: isolate;
}

#available-files-grid {
    position: relative;
    overflow: visible;
}

#available-files-grid .form {
    position: static;
    min-height: 16rem;
}

#available-files-grid .wrap {
    display: grid;
    grid-template-columns: repeat(4, minmax(0, 1fr));
    gap: 0.5rem;
    max-height: 24rem;
    min-height: 16rem;
    overflow-y: auto;
    padding-right: 0.25rem;
}

#available-files-grid .wrap > div {
    min-width: 0;
}

#available-files-grid .wrap label {
    margin: 0;
    display: flex;
    align-items: center;
    padding: 0.45rem 0.65rem;
    border-radius: 0.65rem;
    background-color: rgba(255, 255, 255, 0.05);
    border: 1px solid rgba(255, 255, 255, 0.08);
    transition: background-color 0.2s ease, border-color 0.2s ease;
    min-height: 2.5rem;
}

#available-files-grid .wrap label:hover {
    background-color: rgba(90, 200, 250, 0.16);
    border-color: rgba(90, 200, 250, 0.4);
}

#available-files-grid .wrap label span {
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
}

#available-files-section .gradio-loading,
#available-files-grid .gradio-loading {
    position: absolute;
    top: 0;
    left: 0;
    right: 0;
    bottom: 0;
    width: 100%;
    height: 100%;
    display: flex;
    align-items: center;
    justify-content: center;
    background: rgba(10, 14, 23, 0.92);
    border-radius: 0.75rem;
    z-index: 999;
    padding: 1.5rem;
    pointer-events: auto;
}

#available-files-section .gradio-loading {
    position: absolute;
    inset: 0;
    width: 100%;
    height: 100%;
    display: flex;
    align-items: center;
    justify-content: center;
    background: rgba(10, 14, 23, 0.92);
    border-radius: 0.75rem;
    z-index: 999;
    padding: 1.5rem;
    pointer-events: auto;
}

#available-files-section .gradio-loading > * {
    width: 100%;
}

#available-files-section .gradio-loading progress,
#available-files-section .gradio-loading .progress-bar,
#available-files-section .gradio-loading .loading-progress,
#available-files-section .gradio-loading [role="progressbar"],
#available-files-section .gradio-loading .wrap,
#available-files-section .gradio-loading .inner {
    width: 100% !important;
    max-width: none !important;
}

#available-files-section .gradio-loading .status,
#available-files-section .gradio-loading .message,
#available-files-section .gradio-loading .label {
    text-align: center;
}

#date-browser-row {
    gap: 0.75rem;
}

#date-browser-row .date-browser-column {
    flex: 1 1 0%;
    min-width: 0;
}

#date-browser-row .date-browser-column > .gradio-dropdown,
#date-browser-row .date-browser-column > .gradio-button {
    width: 100%;
}

#date-browser-row .date-browser-column > .gradio-dropdown > div {
    width: 100%;
}

#date-browser-row .date-browser-column .gradio-button {
    justify-content: center;
}

#training-files-summary textarea {
    max-height: 12rem;
    overflow-y: auto;
}

#download-selected-button {
    width: 100%;
    position: relative;
    z-index: 0;
}

#download-selected-button .gradio-button {
    width: 100%;
    justify-content: center;
}

#artifact-download-row {
    gap: 0.75rem;
}

#artifact-download-row .artifact-download-button {
    flex: 1 1 0%;
    min-width: 0;
}

#artifact-download-row .artifact-download-button .gradio-button {
    width: 100%;
    justify-content: center;
}
"""
| |
|
| |
|
| | def _github_cache_key(path: str) -> str: |
| | return path or "__root__" |
| |
|
| |
|
def _github_api_url(path: str) -> str:
    """Build the GitHub contents-API URL for *path* on the data branch."""
    base = f"https://api.github.com/repos/{TRAINING_DATA_REPO}/contents"
    clean_path = path.strip("/")
    suffix = f"/{clean_path}" if clean_path else ""
    return f"{base}{suffix}?ref={TRAINING_DATA_BRANCH}"
| |
|
| |
|
def list_remote_directory(
    path: str = "", *, force_refresh: bool = False
) -> List[Dict[str, Any]]:
    """List one directory of the training-data repo via the GitHub API.

    Results are memoised in ``GITHUB_CONTENT_CACHE``; pass
    ``force_refresh=True`` to bypass the cache for this path.  Raises
    ``RuntimeError`` on HTTP failure or an unexpected payload shape.
    """
    key = _github_cache_key(path)
    if not force_refresh:
        cached = GITHUB_CONTENT_CACHE.get(key)
        if cached is not None:
            return cached

    response = requests.get(_github_api_url(path), timeout=30)
    if response.status_code != 200:
        raise RuntimeError(
            f"GitHub API request failed for `{path or '.'}` (status {response.status_code})."
        )

    payload = response.json()
    if not isinstance(payload, list):
        raise RuntimeError(
            "Unexpected GitHub API payload. Expected a directory listing."
        )

    GITHUB_CONTENT_CACHE[key] = payload
    return payload
| |
|
| |
|
def list_remote_years(force_refresh: bool = False) -> List[str]:
    """Sorted names of the top-level (year) directories in the data repo."""
    return sorted(
        entry["name"]
        for entry in list_remote_directory("", force_refresh=force_refresh)
        if entry.get("type") == "dir"
    )
| |
|
| |
|
def list_remote_months(year: str, *, force_refresh: bool = False) -> List[str]:
    """Sorted month directory names under *year*; empty when no year given."""
    if not year:
        return []
    return sorted(
        entry["name"]
        for entry in list_remote_directory(year, force_refresh=force_refresh)
        if entry.get("type") == "dir"
    )
| |
|
| |
|
def list_remote_days(
    year: str, month: str, *, force_refresh: bool = False
) -> List[str]:
    """Sorted day directory names under ``year/month``; empty if incomplete."""
    if not year or not month:
        return []
    return sorted(
        entry["name"]
        for entry in list_remote_directory(
            f"{year}/{month}", force_refresh=force_refresh
        )
        if entry.get("type") == "dir"
    )
| |
|
| |
|
def list_remote_files(
    year: str, month: str, day: str, *, force_refresh: bool = False
) -> List[str]:
    """Sorted file names under ``year/month/day``; empty if incomplete."""
    if not (year and month and day):
        return []
    return sorted(
        entry["name"]
        for entry in list_remote_directory(
            f"{year}/{month}/{day}", force_refresh=force_refresh
        )
        if entry.get("type") == "file"
    )
| |
|
| |
|
def download_repository_file(year: str, month: str, day: str, filename: str) -> Path:
    """Download one CSV from the raw GitHub mirror into the local cache.

    Raises ``ValueError`` when any path component is missing and
    ``RuntimeError`` when the HTTP request fails.  Returns the path of the
    cached file under ``TRAINING_DATA_DIR``.
    """
    if not filename:
        raise ValueError("Filename cannot be empty when downloading repository data.")

    parts = [component for component in (year, month, day, filename) if component]
    if len(parts) < 4:
        raise ValueError("Provide year, month, day, and filename to download a CSV.")

    relative_path = "/".join(parts)
    raw_url = (
        f"https://raw.githubusercontent.com/{TRAINING_DATA_REPO}/"
        f"{TRAINING_DATA_BRANCH}/{relative_path}"
    )

    response = requests.get(raw_url, stream=True, timeout=120)
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to download `{relative_path}` (status {response.status_code})."
        )

    destination_dir = TRAINING_DATA_DIR / year / month / day
    destination_dir.mkdir(parents=True, exist_ok=True)
    destination = destination_dir / filename

    # Stream to disk in 1 MiB chunks to keep memory bounded for large CSVs.
    with open(destination, "wb") as sink:
        for chunk in response.iter_content(chunk_size=1 << 20):
            if chunk:
                sink.write(chunk)

    return destination
| |
|
| |
|
| | def _normalise_header(name: str) -> str: |
| | return str(name).strip().lower() |
| |
|
| |
|
def guess_label_from_columns(
    columns: Sequence[str], preferred: Optional[str] = None
) -> Optional[str]:
    """Pick the most likely label column from *columns*.

    Preference order: exact (then case-insensitive) match of *preferred*,
    any known training-label alias, any column whose normalised name starts
    with "fault", and finally the first column.  When *columns* is empty,
    *preferred* is returned unchanged.
    """
    if not columns:
        return preferred

    lookup = {_normalise_header(col): str(col) for col in columns}

    if preferred:
        wanted = preferred.strip()
        exact = next((str(c) for c in columns if str(c).strip() == wanted), None)
        if exact is not None:
            return exact
        ci_match = lookup.get(_normalise_header(preferred))
        if ci_match is not None:
            return ci_match

    for alias in TRAINING_LABEL_GUESSES:
        alias_match = lookup.get(_normalise_header(alias))
        if alias_match is not None:
            return alias_match

    for col in columns:
        if _normalise_header(col).startswith("fault"):
            return str(col)

    return str(columns[0])
| |
|
| |
|
def summarise_training_files(paths: Sequence[str], notes: Sequence[str]) -> str:
    """One line per file basename, then any notes; fallback text when empty."""
    entries = [Path(item).name for item in paths]
    entries += list(notes)
    if not entries:
        return "No training files available."
    return "\n".join(entries)
| |
|
| |
|
def read_training_status(status_file_path: str) -> str:
    """Read the current training status from file."""
    status_path = Path(status_file_path)
    try:
        if status_path.exists():
            with open(status_path, "r") as handle:
                return handle.read().strip()
    except Exception:
        # Best effort: any I/O problem falls through to the placeholder.
        pass
    return "Training status unavailable"
| |
|
| |
|
def _persist_uploaded_file(file_obj) -> Optional[Path]:
    """Copy an upload into ``TRAINING_UPLOAD_DIR`` and return its new path.

    Accepts a filesystem path or a Gradio file object (``name``/``path``/
    ``orig_name`` attributes).  The original basename is kept, with a
    ``_N`` counter appended on collision.  Returns ``None`` when the
    upload is missing or its source file cannot be found.
    """
    if file_obj is None:
        return None

    if isinstance(file_obj, (str, Path)):
        source = Path(file_obj)
        upload_name = source.name
    else:
        source = Path(getattr(file_obj, "name", "") or getattr(file_obj, "path", ""))
        upload_name = getattr(file_obj, "orig_name", source.name) or source.name
    if not source or not source.exists():
        return None

    upload_name = Path(upload_name).name or source.name

    name_parts = Path(upload_name)
    destination = TRAINING_UPLOAD_DIR / name_parts.name
    attempt = 1
    while destination.exists():
        # Default to .csv when the original upload carried no extension.
        extension = name_parts.suffix or ".csv"
        destination = TRAINING_UPLOAD_DIR / f"{name_parts.stem}_{attempt}{extension}"
        attempt += 1

    shutil.copy2(source, destination)
    return destination
| |
|
| |
|
def prepare_training_paths(
    paths: Sequence[str], current_label: str, cleanup_missing: bool = False
):
    """Validate candidate training CSVs and build the related UI state.

    Each path is parsed with ``load_measurement_csv``; unreadable files are
    reported in the summary (and deleted when *cleanup_missing* is set).
    Returns ``(valid_paths, summary_text, label_dropdown_update)`` with the
    dropdown seeded by a best-guess label column.
    """
    usable: List[str] = []
    problems: List[str] = []
    seen_columns: Dict[str, str] = {}
    for candidate in paths:
        try:
            frame = load_measurement_csv(candidate)
        except Exception as exc:
            problems.append(f"⚠️ Skipped {Path(candidate).name}: {exc}")
            if cleanup_missing:
                try:
                    Path(candidate).unlink(missing_ok=True)
                except Exception:
                    pass
            continue
        usable.append(str(candidate))
        for column in frame.columns:
            seen_columns[_normalise_header(column)] = str(column)

    summary = summarise_training_files(usable, problems)
    preferred = current_label or LABEL_COLUMN
    if seen_columns:
        choices = sorted(seen_columns.values())
    else:
        choices = [preferred or LABEL_COLUMN]
    guessed = guess_label_from_columns(choices, preferred)
    selected = guessed or preferred or LABEL_COLUMN

    return (
        usable,
        summary,
        gr.update(choices=choices, value=selected),
    )
| |
|
| |
|
def append_training_files(new_files, existing_paths: Sequence[str], current_label: str):
    """Persist newly uploaded files and merge them into the existing list."""
    if existing_paths is None:
        merged: List[str] = []
    elif isinstance(existing_paths, (str, Path)):
        merged = [str(existing_paths)]
    else:
        merged = list(existing_paths)

    for upload in new_files or []:
        saved = _persist_uploaded_file(upload)
        if saved is None:
            continue
        saved_str = str(saved)
        if saved_str not in merged:
            merged.append(saved_str)

    return prepare_training_paths(merged, current_label, cleanup_missing=True)
| |
|
| |
|
def load_repository_training_files(current_label: str, force_refresh: bool = False):
    """Collect every cached database CSV and build training UI state.

    Returns ``(paths, summary, label_dropdown_update, info_message)``.

    ``force_refresh`` is accepted for interface compatibility; the previous
    implementation contained a leftover ``for ... break`` loop over
    ``TRAINING_DATA_DIR`` under that flag which executed at most one empty
    iteration and had no effect, so it has been removed.
    """
    csv_paths = sorted(
        str(path) for path in TRAINING_DATA_DIR.rglob("*.csv") if path.is_file()
    )
    if not csv_paths:
        message = (
            "No local database CSVs are available yet. Use the database browser "
            "below to download specific days before training."
        )
        default_label = current_label or LABEL_COLUMN or "Fault"
        return (
            [],
            message,
            gr.update(choices=[default_label], value=default_label),
            message,
        )

    valid_paths, summary, label_update = prepare_training_paths(
        csv_paths, current_label, cleanup_missing=False
    )

    info = (
        f"Ready with {len(valid_paths)} CSV file(s) cached locally under "
        f"the database cache `{TRAINING_DATA_DIR}`."
    )

    return valid_paths, summary, label_update, info
| |
|
| |
|
def refresh_remote_browser(force_refresh: bool = False):
    """Reset the remote year/month/day/file pickers.

    Returns updates for the year, month, and day dropdowns, the file
    checkbox group, and a status message; API failures empty the widgets
    and surface the error instead of raising.
    """
    def _blank_dropdown():
        return gr.update(choices=[], value=None)

    if force_refresh:
        GITHUB_CONTENT_CACHE.clear()
    try:
        years = list_remote_years(force_refresh=force_refresh)
    except Exception as exc:
        return (
            _blank_dropdown(),
            _blank_dropdown(),
            _blank_dropdown(),
            gr.update(choices=[], value=[]),
            f"⚠️ Failed to query database: {exc}",
        )

    message = (
        "Select a year, month, and day to list available CSV files."
        if years
        else "⚠️ No directories were found in the database root. Verify the upstream "
        "structure."
    )
    return (
        gr.update(choices=years, value=None),
        _blank_dropdown(),
        _blank_dropdown(),
        gr.update(choices=[], value=[]),
        message,
    )
| |
|
| |
|
def on_year_change(year: Optional[str]):
    """React to a year selection by loading its month folders."""
    if not year:
        return (
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=[]),
            "Select a year to continue.",
        )
    try:
        months = list_remote_months(year)
    except Exception as exc:
        return (
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=[]),
            f"⚠️ Failed to list months: {exc}",
        )
    if months:
        message = f"Year `{year}` selected. Choose a month to drill down."
    else:
        message = f"⚠️ No months available under `{year}`."
    return (
        gr.update(choices=months, value=None),
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=[]),
        message,
    )
| |
|
| |
|
def on_month_change(year: Optional[str], month: Optional[str]):
    """React to a month selection by loading its day folders."""
    if not year or not month:
        return (
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=[]),
            "Select a month to continue.",
        )
    try:
        days = list_remote_days(year, month)
    except Exception as exc:
        return (
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=[]),
            f"⚠️ Failed to list days: {exc}",
        )
    if days:
        message = f"Month `{year}/{month}` ready. Pick a day to view files."
    else:
        message = f"⚠️ No day folders found under `{year}/{month}`."
    return (
        gr.update(choices=days, value=None),
        gr.update(choices=[], value=[]),
        message,
    )
| |
|
| |
|
def on_day_change(year: Optional[str], month: Optional[str], day: Optional[str]):
    """React to a day selection by listing its CSV files."""
    if not (year and month and day):
        return (
            gr.update(choices=[], value=[]),
            "Select a day to load file names.",
        )
    try:
        files = list_remote_files(year, month, day)
    except Exception as exc:
        return (
            gr.update(choices=[], value=[]),
            f"⚠️ Failed to list files: {exc}",
        )
    if files:
        message = f"{len(files)} file(s) available for `{year}/{month}/{day}`."
    else:
        message = f"⚠️ No CSV files found under `{year}/{month}/{day}`."
    return (
        gr.update(choices=files, value=[]),
        message,
    )
| |
|
| |
|
def download_selected_files(
    year: Optional[str],
    month: Optional[str],
    day: Optional[str],
    filenames: Sequence[str],
    current_label: str,
):
    """Download the chosen CSVs for one day into the local cache.

    Returns the refreshed local training state (paths, summary, label
    dropdown update, info), an update for the file checkbox group, and a
    status message.

    Bug fix: per-file failure notes previously printed the literal
    "(unknown)" instead of the offending filename.
    """
    if not filenames:
        message = "Select at least one CSV before downloading."
        local = load_repository_training_files(current_label)
        return (*local, gr.update(), message)

    success: List[str] = []
    notes: List[str] = []
    for filename in filenames:
        try:
            path = download_repository_file(
                year or "", month or "", day or "", filename
            )
            success.append(str(path))
        except Exception as exc:
            # Name the file that failed so the user can retry it.
            notes.append(f"⚠️ {filename}: {exc}")

    local = load_repository_training_files(current_label)

    message_lines = []
    if success:
        message_lines.append(
            f"Downloaded {len(success)} file(s) to the database cache `{TRAINING_DATA_DIR}`."
        )
    if notes:
        message_lines.extend(notes)
    if not message_lines:
        message_lines.append("No files were downloaded.")

    return (*local, gr.update(value=[]), "\n".join(message_lines))
| |
|
| |
|
def download_day_bundle(
    year: Optional[str],
    month: Optional[str],
    day: Optional[str],
    current_label: str,
):
    """Download every CSV for one day, prepending a summary line."""
    if not (year and month and day):
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            "Select a year, month, and day before downloading an entire day.",
        )

    try:
        files = list_remote_files(year, month, day)
    except Exception as exc:
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            f"⚠️ Failed to list CSVs for `{year}/{month}/{day}`: {exc}",
        )

    if not files:
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            f"No CSV files were found for `{year}/{month}/{day}`.",
        )

    # Delegate the actual downloading, then prefix a day-level summary.
    outputs = list(download_selected_files(year, month, day, files, current_label))
    header = f"Downloaded all {len(files)} CSV file(s) for `{year}/{month}/{day}`."
    outputs[-1] = f"{header}\n{outputs[-1]}"
    return tuple(outputs)
| |
|
| |
|
def download_month_bundle(
    year: Optional[str], month: Optional[str], current_label: str
):
    """Download every CSV under ``year/month`` into the local cache.

    Returns the refreshed local training state, a checkbox-group reset, and
    a status message built from the per-day results.

    Bug fix: per-file failure notes previously printed the literal
    "(unknown)" where the filename belongs.
    """
    if not (year and month):
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            "Select a year and month before downloading an entire month.",
        )

    try:
        days = list_remote_days(year, month)
    except Exception as exc:
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            f"⚠️ Failed to enumerate days for `{year}/{month}`: {exc}",
        )

    if not days:
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            f"No day folders were found for `{year}/{month}`.",
        )

    downloaded = 0
    notes: List[str] = []
    for day in days:
        try:
            files = list_remote_files(year, month, day)
        except Exception as exc:
            notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
            continue
        if not files:
            notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
            continue
        for filename in files:
            try:
                download_repository_file(year, month, day, filename)
                downloaded += 1
            except Exception as exc:
                # Include the actual filename (was the literal "(unknown)").
                notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")

    local = load_repository_training_files(current_label)
    message_lines = []
    if downloaded:
        message_lines.append(
            f"Downloaded {downloaded} CSV file(s) for `{year}/{month}` into the "
            f"database cache `{TRAINING_DATA_DIR}`."
        )
    message_lines.extend(notes)
    if not message_lines:
        message_lines.append("No files were downloaded.")

    return (*local, gr.update(value=[]), "\n".join(message_lines))
| |
|
| |
|
def download_year_bundle(year: Optional[str], current_label: str):
    """Download every CSV under *year* into the local cache.

    Walks all month and day folders, downloading each CSV and collecting
    per-path failure notes.  Returns the refreshed local training state, a
    checkbox-group reset, and a status message.

    Bug fix: per-file failure notes previously printed the literal
    "(unknown)" where the filename belongs.
    """
    if not year:
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            "Select a year before downloading an entire year of CSVs.",
        )

    try:
        months = list_remote_months(year)
    except Exception as exc:
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            f"⚠️ Failed to enumerate months for `{year}`: {exc}",
        )

    if not months:
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            f"No month folders were found for `{year}`.",
        )

    downloaded = 0
    notes: List[str] = []
    for month in months:
        try:
            days = list_remote_days(year, month)
        except Exception as exc:
            notes.append(f"⚠️ Failed to list `{year}/{month}`: {exc}")
            continue
        if not days:
            notes.append(f"⚠️ No day folders in `{year}/{month}`.")
            continue
        for day in days:
            try:
                files = list_remote_files(year, month, day)
            except Exception as exc:
                notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
                continue
            if not files:
                notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
                continue
            for filename in files:
                try:
                    download_repository_file(year, month, day, filename)
                    downloaded += 1
                except Exception as exc:
                    # Include the actual filename (was the literal "(unknown)").
                    notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")

    local = load_repository_training_files(current_label)
    message_lines = []
    if downloaded:
        message_lines.append(
            f"Downloaded {downloaded} CSV file(s) for `{year}` into the "
            f"database cache `{TRAINING_DATA_DIR}`."
        )
    message_lines.extend(notes)
    if not message_lines:
        message_lines.append("No files were downloaded.")

    return (*local, gr.update(value=[]), "\n".join(message_lines))
| |
|
| |
|
def clear_downloaded_cache(current_label: str):
    """Wipe the local database cache and refresh local + remote UI state."""
    try:
        if TRAINING_DATA_DIR.exists():
            shutil.rmtree(TRAINING_DATA_DIR)
        TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
        status_message = (
            f"Cleared all downloaded CSVs from database cache `{TRAINING_DATA_DIR}`."
        )
    except Exception as exc:
        status_message = f"⚠️ Failed to clear database cache: {exc}"

    local = load_repository_training_files(current_label, force_refresh=True)
    remote = list(refresh_remote_browser(force_refresh=False))
    # Prepend our status to the remote browser's own message (last slot).
    if status_message:
        tail = remote[-1]
        if isinstance(tail, str) and tail:
            remote[-1] = f"{status_message}\n{tail}"
        else:
            remote[-1] = status_message

    return (*local, *remote)
| |
|
| |
|
def normalise_output_directory(directory: Optional[str]) -> Path:
    """Expand *directory* (default ``MODEL_OUTPUT_DIR``) to an absolute path."""
    target = Path(directory or MODEL_OUTPUT_DIR).expanduser()
    if target.is_absolute():
        return target
    return (Path.cwd() / target).resolve()
| |
|
| |
|
def resolve_output_path(
    directory: Optional[Union[Path, str]], filename: Optional[str], fallback: str
) -> Path:
    """Resolve where an output artefact should be written.

    An absolute *filename* wins outright; a relative one is resolved under
    *directory* (normalised first when given as a string or ``None``); an
    empty/missing *filename* falls back to *fallback* under *directory*.

    Bug fix: ``Path("")`` normalises to ``"."``, so ``str(candidate)`` was
    always truthy and the *fallback* branch was unreachable — an empty
    filename silently resolved to the directory itself.  The emptiness
    check now happens before constructing the ``Path``.
    """
    if isinstance(directory, Path):
        base = directory
    else:
        base = normalise_output_directory(directory)
    if not filename:
        return (base / fallback).resolve()
    candidate = Path(filename).expanduser()
    if candidate.is_absolute():
        return candidate
    return (base / candidate).resolve()
| |
|
| |
|
# File extensions treated as downloadable artefacts by the directory browser
# (compared case-insensitively against Path.suffix in gather_artifact_choices).
ARTIFACT_FILE_EXTENSIONS: Tuple[str, ...] = (
    ".keras",
    ".h5",
    ".joblib",
    ".pkl",
    ".json",
    ".onnx",
    ".zip",
    ".txt",
)
| |
|
| |
|
def gather_directory_choices(current: Optional[str]) -> Tuple[List[str], str]:
    """Offer the current output directory plus its sibling directories.

    Returns ``(sorted_choices, normalised_current)``; sibling enumeration
    failures are ignored so the current directory is always offered.
    """
    base = normalise_output_directory(current or str(MODEL_OUTPUT_DIR))
    found = {str(base)}
    try:
        for entry in base.parent.iterdir():
            if entry.is_dir():
                found.add(str(entry.resolve()))
    except Exception:
        pass
    return sorted(found), str(base)
| |
|
| |
|
def gather_artifact_choices(
    directory: Optional[str], selection: Optional[str] = None
) -> Tuple[List[Tuple[str, str]], Optional[str]]:
    """Enumerate downloadable artefacts in *directory*.

    Returns ``(choices, selected)``: choices are ``(name, path)`` pairs
    sorted case-insensitively by name; selected keeps *selection* when it
    is still present, else the first artefact, else ``None``.
    """
    base = normalise_output_directory(directory)
    choices: List[Tuple[str, str]] = []
    if base.exists():
        try:
            matching = [
                entry
                for entry in base.iterdir()
                if entry.is_file()
                and (
                    not ARTIFACT_FILE_EXTENSIONS
                    or entry.suffix.lower() in ARTIFACT_FILE_EXTENSIONS
                )
            ]
            matching.sort(key=lambda entry: entry.name.lower())
            choices = [(entry.name, str(entry)) for entry in matching]
        except Exception:
            choices = []

    valid_values = {value for _, value in choices}
    if selection and selection in valid_values:
        return choices, selection
    if choices:
        return choices, choices[0][1]
    return choices, None
| |
|
| |
|
def download_button_state(path: Optional[Union[str, Path]]):
    """Return a ``gr.update`` for a DownloadButton tied to *path*.

    The button is shown (and pointed at the file) only when *path* names
    an existing file; otherwise it is hidden and its value cleared.
    """
    if path:
        target = Path(path)
        if target.exists():
            return gr.update(value=str(target), visible=True)
    return gr.update(value=None, visible=False)
| |
|
| |
|
def clear_training_files():
    """Delete cached training uploads and reset the related UI widgets.

    Returns the cleared file-list state, a status string, a label-dropdown
    update restoring the default label, and a cleared file input.
    """
    fallback_label = LABEL_COLUMN or "Fault"
    for entry in TRAINING_UPLOAD_DIR.glob("*"):
        try:
            if entry.is_file():
                entry.unlink(missing_ok=True)
        except Exception:
            # Best effort: a file that cannot be removed is simply left behind.
            pass
    return (
        [],
        "No training files selected.",
        gr.update(choices=[fallback_label], value=fallback_label),
        gr.update(value=None),
    )
| |
|
| |
|
# Markdown rendered on the "Overview" tab.  This is runtime UI text shown
# to users, so any wording change here changes what the app displays.
PROJECT_OVERVIEW_MD = """
## Project Overview

This project focuses on classifying faults in electrical transmission lines and
grid-connected photovoltaic (PV) systems by combining ensemble learning
techniques with deep neural architectures.

## Datasets

### Transmission Line Fault Dataset
- 134,406 samples collected from Phasor Measurement Units (PMUs)
- 14 monitored channels covering currents, voltages, magnitudes, frequency, and phase angles
- Labels span symmetrical and asymmetrical faults: NF, L-G, LL, LL-G, LLL, and LLL-G
- Time span: 0 to 5.7 seconds with high-frequency sampling

### Grid-Connected PV System Fault Dataset
- 2,163,480 samples from 16 experimental scenarios
- 14 features including PV array measurements (Ipv, Vpv, Vdc), three-phase currents/voltages, aggregate magnitudes (Iabc, Vabc), and frequency indicators (If, Vf)
- Captures array, inverter, grid anomaly, feedback sensor, and MPPT controller faults at 9.9989 μs sampling intervals

## Data Format Quick Reference

Each measurement file may be comma or tab separated and typically exposes the
following ordered columns:

1. `Timestamp`
2. `[325] UPMU_SUB22:FREQ` – system frequency (Hz)
3. `[326] UPMU_SUB22:DFDT` – frequency rate-of-change
4. `[327] UPMU_SUB22:FLAG` – PMU status flag
5. `[328] UPMU_SUB22-L1:MAG` – phase A voltage magnitude
6. `[329] UPMU_SUB22-L1:ANG` – phase A voltage angle
7. `[330] UPMU_SUB22-L2:MAG` – phase B voltage magnitude
8. `[331] UPMU_SUB22-L2:ANG` – phase B voltage angle
9. `[332] UPMU_SUB22-L3:MAG` – phase C voltage magnitude
10. `[333] UPMU_SUB22-L3:ANG` – phase C voltage angle
11. `[334] UPMU_SUB22-C1:MAG` – phase A current magnitude
12. `[335] UPMU_SUB22-C1:ANG` – phase A current angle
13. `[336] UPMU_SUB22-C2:MAG` – phase B current magnitude
14. `[337] UPMU_SUB22-C2:ANG` – phase B current angle
15. `[338] UPMU_SUB22-C3:MAG` – phase C current magnitude
16. `[339] UPMU_SUB22-C3:ANG` – phase C current angle

The training tab automatically downloads the latest CSV exports from the
`VincentCroft/ThesisModelData` repository and concatenates them before building
sliding windows.

## Models Developed

1. **Support Vector Machine (SVM)** – provides the classical machine learning baseline with balanced accuracy across both datasets (85% PMU / 83% PV).
2. **CNN-LSTM** – couples convolutional feature extraction with temporal memory, achieving 92% PMU / 89% PV accuracy.
3. **Temporal Convolutional Network (TCN)** – leverages dilated convolutions for long-range context and delivers the best trade-off between accuracy and training time (94% PMU / 91% PV).

## Results Summary

- **Transmission Line Fault Classification**: SVM 85%, CNN-LSTM 92%, TCN 94%
- **PV System Fault Classification**: SVM 83%, CNN-LSTM 89%, TCN 91%

Use the **Inference** tab to score new PMU/PV windows and the **Training** tab to
fine-tune or retrain any of the supported models directly within Hugging Face
Spaces. The logs panel will surface TensorBoard archives whenever deep-learning
models are trained.
"""
| |
|
| |
|
def load_measurement_csv(path: str) -> pd.DataFrame:
    """Read a PMU/PV measurement file with flexible separators and column mapping.

    Parsing strategy:
    1. Try pandas separator sniffing (``sep=None`` + python engine), then
       fall back to tab / comma / semicolon explicitly.
    2. Add a dummy ``Fault`` label column when none of the usual label
       names is present.
    3. Positionally rename columns onto ``REQUIRED_PMU_COLUMNS`` — first
       using everything after ``Timestamp``, then (if columns are still
       missing) using the numeric columns in order.

    Raises:
        ValueError: when the required PMU columns still cannot be mapped.

    NOTE(review): the positional mapping assumes the file's column order
    matches the training layout — a reordered export would be silently
    mis-mapped; verify against the data-format reference in the Overview.
    """

    try:
        # utf-8-sig strips a BOM if present; sep=None lets pandas sniff.
        df = pd.read_csv(path, sep=None, engine="python", encoding="utf-8-sig")
    except Exception:
        df = None
        # Sniffing failed: try the common delimiters one by one.
        for separator in ("\t", ",", ";"):
            try:
                df = pd.read_csv(
                    path, sep=separator, engine="python", encoding="utf-8-sig"
                )
                break
            except Exception:
                df = None
        if df is None:
            # Re-raise the original sniffing error when every fallback failed.
            raise

    # Normalise header whitespace so later name comparisons are reliable.
    df.columns = [str(col).strip() for col in df.columns]

    print(f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns")
    print(f"Columns: {list(df.columns)}")
    print(f"Data shape: {df.shape}")

    if len(df) < 100:
        print(
            f"Warning: Only {len(df)} rows of data. Recommend at least 1000 rows for effective training."
        )

    # Inference-only files have no label; add a placeholder so downstream
    # code that expects a label column keeps working.
    has_label = any(
        col.lower() in ["fault", "label", "class", "target"] for col in df.columns
    )
    if not has_label:
        print(
            "Warning: No label column found. Adding dummy 'Fault' column with value 'Normal' for all samples."
        )
        df["Fault"] = "Normal"

    column_mapping = {}
    expected_cols = list(REQUIRED_PMU_COLUMNS)

    # Pass 1: map every column after Timestamp onto the expected names,
    # purely by position.
    if "Timestamp" in df.columns:
        numeric_cols = [col for col in df.columns if col != "Timestamp"]
        if len(numeric_cols) >= len(expected_cols):
            for i, expected_col in enumerate(expected_cols):
                if i < len(numeric_cols):
                    column_mapping[numeric_cols[i]] = expected_col

    df = df.rename(columns=column_mapping)

    # Pass 2: if names are still missing, retry using only the numeric
    # columns (handles files without a Timestamp header).
    missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]
    if missing:
        available_numeric = df.select_dtypes(include=[np.number]).columns.tolist()
        if len(available_numeric) >= len(expected_cols):
            for i, expected_col in enumerate(expected_cols):
                if i < len(available_numeric):
                    if available_numeric[i] not in df.columns:
                        continue
                    # Rename one at a time; earlier renames may already
                    # have consumed some of the snapshot's names.
                    df = df.rename(columns={available_numeric[i]: expected_col})

    missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]

    if missing:
        missing_str = ", ".join(missing)
        available_str = ", ".join(df.columns.tolist())
        raise ValueError(
            f"Missing required PMU feature columns: {missing_str}. "
            f"Available columns: {available_str}. "
            "Please ensure your CSV has the correct format with Timestamp followed by PMU measurements."
        )

    return df
| |
|
| |
|
def apply_metadata(metadata: Dict[str, Any]) -> None:
    """Push values from a metadata dict into the module-level settings.

    Missing keys fall back to the training defaults; ``model_format``
    defaults to ``joblib`` for SVM models and ``keras`` otherwise.
    """
    global FEATURE_COLUMNS, LABEL_CLASSES, LABEL_COLUMN, SEQUENCE_LENGTH, DEFAULT_WINDOW_STRIDE, MODEL_TYPE, MODEL_FORMAT

    FEATURE_COLUMNS = [
        str(name)
        for name in metadata.get("feature_columns", DEFAULT_FEATURE_COLUMNS)
    ]
    LABEL_CLASSES = [str(cls) for cls in metadata.get("label_classes", [])]
    LABEL_COLUMN = str(metadata.get("label_column", "Fault"))
    SEQUENCE_LENGTH = int(metadata.get("sequence_length", DEFAULT_SEQUENCE_LENGTH))
    DEFAULT_WINDOW_STRIDE = int(metadata.get("stride", DEFAULT_STRIDE))
    MODEL_TYPE = str(metadata.get("model_type", "cnn_lstm")).lower()
    default_format = "joblib" if MODEL_TYPE == "svm" else "keras"
    MODEL_FORMAT = str(metadata.get("model_format", default_format)).lower()
| |
|
| |
|
# Initialise the module-level settings from the metadata loaded at startup.
apply_metadata(METADATA)
| |
|
| |
|
def sync_label_classes_from_model(model: Optional[object]) -> None:
    """Derive ``LABEL_CLASSES`` from a loaded model when possible.

    sklearn estimators expose ``classes_`` directly.  For models without
    that attribute but with an ``output_shape`` (Keras), numeric class
    names are synthesised — but only when no label names are known yet.
    """
    global LABEL_CLASSES

    if model is None:
        return
    _missing = object()
    classes = getattr(model, "classes_", _missing)
    if classes is not _missing:
        LABEL_CLASSES = [str(item) for item in classes]
    elif not LABEL_CLASSES and hasattr(model, "output_shape"):
        width = int(model.output_shape[-1])
        LABEL_CLASSES = [str(index) for index in range(width)]
| |
|
| |
|
| | |
# ---------------------------------------------------------------------------
# Startup: load the persisted artifacts.  Every failure is tolerated (the
# globals are left as None) so the UI can still come up and explain which
# files are missing instead of crashing the Space on boot.
# ---------------------------------------------------------------------------
print("Loading model and scaler...")
try:
    MODEL = try_load_model(MODEL_PATH, MODEL_TYPE, MODEL_FORMAT)
    print(f"Model loaded: {MODEL is not None}")
except Exception as e:
    print(f"Model loading failed: {e}")
    MODEL = None

try:
    SCALER = try_load_scaler(SCALER_PATH)
    print(f"Scaler loaded: {SCALER is not None}")
except Exception as e:
    print(f"Scaler loading failed: {e}")
    SCALER = None

try:
    # Best effort: mirror the model's class list into LABEL_CLASSES.
    sync_label_classes_from_model(MODEL)
    print("Label classes synchronized")
except Exception as e:
    print(f"Label sync failed: {e}")

print("Application initialization completed.")
print(
    f"Ready to start Gradio interface. Model available: {MODEL is not None}, Scaler available: {SCALER is not None}"
)
| |
|
| |
|
def refresh_artifacts(model_path: Path, scaler_path: Path, metadata_path: Path) -> None:
    """Re-point the module at freshly written artifacts and reload them.

    Updates the cached paths, re-reads the metadata (which reconfigures
    feature columns, sequence length, model type, etc.), reloads the model
    and scaler, and re-syncs the label classes from the model.
    """
    global MODEL_PATH, SCALER_PATH, METADATA_PATH, MODEL, SCALER, METADATA

    MODEL_PATH, SCALER_PATH, METADATA_PATH = model_path, scaler_path, metadata_path
    METADATA = load_metadata(metadata_path)
    apply_metadata(METADATA)
    # apply_metadata just refreshed MODEL_TYPE/MODEL_FORMAT; load with those.
    MODEL = try_load_model(model_path, MODEL_TYPE, MODEL_FORMAT)
    SCALER = try_load_scaler(scaler_path)
    sync_label_classes_from_model(MODEL)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def ensure_ready():
    """Raise ``RuntimeError`` unless both model and scaler are loaded.

    Called at the top of every prediction path so users get an actionable
    message instead of an ``AttributeError`` on ``None``.
    """
    if MODEL is not None and SCALER is not None:
        return
    raise RuntimeError(
        "The model and feature scaler are not available. Upload the trained model "
        "(for example `pmu_cnn_lstm_model.keras`, `pmu_tcn_model.keras`, or `pmu_svm_model.joblib`), "
        "the feature scaler (`pmu_feature_scaler.pkl`), and the metadata JSON (`pmu_metadata.json`) to the Space root "
        "or configure the Hugging Face Hub environment variables so the artifacts can be downloaded "
        "automatically."
    )
| |
|
| |
|
def parse_text_features(text: str) -> np.ndarray:
    """Parse a comma/semicolon/whitespace separated string of numbers.

    Accepts the delimiters users typically paste (commas, semicolons,
    newlines, tabs, spaces) and returns a 1-D ``float32`` array.

    Raises:
        ValueError: when the input is empty or contains a non-numeric token.

    Note: this previously used ``np.fromstring``, which is deprecated for
    text input and silently truncates at the first malformed token;
    explicit tokenising raises a clear error instead.
    """
    tokens = [tok for tok in re.split(r"[,;\s]+", text.strip()) if tok]
    if not tokens:
        raise ValueError(
            "No feature values were parsed. Please enter comma-separated numbers."
        )
    try:
        values = [float(tok) for tok in tokens]
    except ValueError as exc:
        raise ValueError(
            f"Could not parse feature values: {exc}. "
            "Please enter comma-separated numbers."
        ) from exc
    return np.asarray(values, dtype=np.float32)
| |
|
| |
|
def apply_scaler(sequences: np.ndarray) -> np.ndarray:
    """Scale features with the fitted scaler, preserving the window shape.

    The scaler was fitted on 2-D ``(samples, features)`` data, so the 3-D
    window tensor is flattened to 2-D, transformed, and reshaped back.
    When no scaler is loaded the input is returned untouched.
    """
    if SCALER is None:
        return sequences
    original_shape = sequences.shape
    two_d = sequences.reshape(-1, original_shape[-1])
    return SCALER.transform(two_d).reshape(original_shape)
| |
|
| |
|
def make_sliding_windows(
    data: np.ndarray, sequence_length: int, stride: int
) -> np.ndarray:
    """Cut a ``(rows, features)`` array into overlapping time windows.

    Returns an array of shape ``(n_windows, sequence_length, features)``.

    Raises:
        ValueError: when there are fewer rows than *sequence_length*.
    """
    total_rows = data.shape[0]
    if total_rows < sequence_length:
        raise ValueError(
            f"The dataset contains {total_rows} rows which is less than the requested sequence "
            f"length {sequence_length}. Provide more samples or reduce the sequence length."
        )
    starts = range(0, total_rows - sequence_length + 1, stride)
    return np.stack([data[pos : pos + sequence_length] for pos in starts])
| |
|
| |
|
def dataframe_to_sequences(
    df: pd.DataFrame,
    *,
    sequence_length: int,
    stride: int,
    feature_columns: Sequence[str],
    drop_label: bool = True,
) -> np.ndarray:
    """Convert a measurement dataframe into model-ready window tensors.

    Three layouts are accepted, tried in order:

    1. Raw time series containing every *feature_columns* entry — sorted
       by ``Timestamp`` (when present) and cut into sliding windows.
    2. Pre-flattened windows with ``n_features * sequence_length`` numeric
       columns — reshaped directly, one window per row.
    3. Single-timestep rows (``sequence_length == 1``) with exactly
       ``n_features`` numeric columns.

    Raises:
        ValueError: when none of the layouts match.
    """
    frame = df.copy()
    if drop_label and LABEL_COLUMN in frame.columns:
        frame = frame.drop(columns=[LABEL_COLUMN])
    if "Timestamp" in frame.columns:
        frame = frame.sort_values("Timestamp")

    n_features = len(feature_columns)
    present = [name for name in feature_columns if name in frame.columns]
    if present and len(present) == n_features:
        series = frame[present].astype(np.float32).to_numpy()
        return make_sliding_windows(series, sequence_length, stride)

    numeric = frame.select_dtypes(include=[np.number]).astype(np.float32).to_numpy()
    if numeric.shape[1] == n_features * sequence_length:
        return numeric.reshape(numeric.shape[0], sequence_length, n_features)
    if sequence_length == 1 and numeric.shape[1] == n_features:
        return numeric.reshape(numeric.shape[0], 1, n_features)
    raise ValueError(
        "CSV columns do not match the expected feature layout. Include the full PMU feature set "
        "or provide pre-shaped sliding window data."
    )
| |
|
| |
|
def label_name(index: int) -> str:
    """Map a class index to its human-readable label.

    Falls back to ``class_<index>`` when the index lies outside the known
    label list (e.g. before metadata has been loaded).
    """
    if index < 0 or index >= len(LABEL_CLASSES):
        return f"class_{index}"
    return str(LABEL_CLASSES[index])
| |
|
| |
|
def format_predictions(probabilities: np.ndarray) -> pd.DataFrame:
    """Summarise per-window class probabilities as a display table.

    Each row holds the window index, the top-1 label and its confidence
    (rounded to 4 dp), plus a ``"label (pp.pp%)"`` summary of the three
    most likely classes.
    """
    descending = np.argsort(probabilities, axis=1)[:, ::-1]
    records: List[Dict[str, object]] = []
    for window_idx, (row, ranking) in enumerate(zip(probabilities, descending)):
        best = int(ranking[0])
        top3 = " | ".join(
            f"{label_name(i)} ({row[i]*100:.2f}%)" for i in ranking[:3]
        )
        records.append(
            {
                "window": window_idx,
                "predicted_label": label_name(best),
                "confidence": round(float(row[best]), 4),
                "top3": top3,
            }
        )
    return pd.DataFrame(records)
| |
|
| |
|
def probabilities_to_json(probabilities: np.ndarray) -> List[Dict[str, object]]:
    """Serialise the full probability matrix for the JSON output widget.

    Produces one entry per window mapping every class label to its
    probability as a plain Python ``float`` (JSON-safe).
    """
    return [
        {
            "window": int(window_idx),
            "probabilities": {
                label_name(class_idx): float(row[class_idx])
                for class_idx in range(row.shape[0])
            },
        }
        for window_idx, row in enumerate(probabilities)
    ]
| |
|
| |
|
def predict_sequences(
    sequences: np.ndarray,
) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
    """Run the loaded model over scaled windows and format the outputs.

    Returns a status string, the prediction table, and the per-window
    probability payload.

    Raises:
        RuntimeError: when the model/scaler are missing, or when an SVM
            model was trained without ``probability=True``.
    """
    ensure_ready()
    scaled = apply_scaler(sequences.astype(np.float32))
    if MODEL_TYPE == "svm":
        # sklearn expects flat (samples, features) input.
        flat = scaled.reshape(scaled.shape[0], -1)
        if not hasattr(MODEL, "predict_proba"):
            raise RuntimeError(
                "Loaded SVM model does not expose predict_proba. Retrain with probability=True."
            )
        probs = MODEL.predict_proba(flat)
    else:
        probs = MODEL.predict(scaled, verbose=0)
    architecture = MODEL_TYPE.replace("_", "-").upper()
    status = f"Generated {len(scaled)} windows. {architecture} model output dimension: {probs.shape[1]}."
    return status, format_predictions(probs), probabilities_to_json(probs)
| |
|
| |
|
def predict_from_text(
    text: str, sequence_length: int
) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
    """Predict from a manually pasted, comma-separated feature window.

    The value count must equal ``sequence_length * len(FEATURE_COLUMNS)``;
    mismatches raise ``ValueError`` with a hint about the expected layout.
    """
    values = parse_text_features(text)
    n_features = len(FEATURE_COLUMNS)
    if values.size % n_features != 0:
        raise ValueError(
            f"The number of values ({values.size}) is not a multiple of the feature dimension "
            f"({n_features}). Provide values in groups of {n_features}."
        )
    timesteps = values.size // n_features
    if timesteps != sequence_length:
        raise ValueError(
            f"Detected {timesteps} timesteps which does not match the configured sequence length "
            f"({sequence_length})."
        )
    window = values.reshape(1, sequence_length, n_features)
    status, table, probs = predict_sequences(window)
    return f"Single window prediction complete. {status}", table, probs
| |
|
| |
|
def predict_from_csv(
    file_obj, sequence_length: int, stride: int
) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
    """Predict from an uploaded CSV by windowing its time series.

    *file_obj* is the Gradio upload handle; ``.name`` holds the temp-file
    path on disk.
    """
    frame = load_measurement_csv(file_obj.name)
    windows = dataframe_to_sequences(
        frame,
        sequence_length=sequence_length,
        stride=stride,
        feature_columns=FEATURE_COLUMNS,
    )
    status, table, probs = predict_sequences(windows)
    prefix = f"CSV processed successfully. Generated {len(windows)} windows. "
    return prefix + status, table, probs
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def classification_report_to_dataframe(report: Dict[str, Any]) -> pd.DataFrame:
    """Flatten sklearn's ``classification_report(output_dict=True)``.

    Per-class entries become rows with metrics rounded to 4 dp
    (``support`` kept as an int); scalar entries such as ``accuracy``
    become their own single-metric row.
    """
    records: List[Dict[str, Any]] = []
    for label, metrics in report.items():
        if not isinstance(metrics, dict):
            records.append({"label": label, "accuracy": round(float(metrics), 4)})
            continue
        entry: Dict[str, Any] = {"label": label}
        for metric_name, metric_value in metrics.items():
            entry[metric_name] = (
                int(metric_value)
                if metric_name == "support"
                else round(float(metric_value), 4)
            )
        records.append(entry)
    return pd.DataFrame(records)
| |
|
| |
|
def confusion_matrix_to_dataframe(
    confusion: Sequence[Sequence[float]], labels: Sequence[str]
) -> pd.DataFrame:
    """Wrap a confusion matrix in a labelled DataFrame for display.

    Returns an empty frame when *confusion* is empty.  Axis names mark
    rows as true labels and columns as predicted labels.
    """
    if not confusion:
        return pd.DataFrame()
    axis_labels = list(labels)
    frame = pd.DataFrame(confusion, index=axis_labels, columns=axis_labels)
    frame.index.name = "True Label"
    frame.columns.name = "Predicted Label"
    return frame
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def build_interface() -> gr.Blocks: |
| | theme = gr.themes.Soft( |
| | primary_hue="sky", secondary_hue="blue", neutral_hue="gray" |
| | ).set( |
| | body_background_fill="#1f1f1f", |
| | body_text_color="#f5f5f5", |
| | block_background_fill="#262626", |
| | block_border_color="#333333", |
| | button_primary_background_fill="#5ac8fa", |
| | button_primary_background_fill_hover="#48b5eb", |
| | button_primary_border_color="#38bdf8", |
| | button_primary_text_color="#0f172a", |
| | button_secondary_background_fill="#3f3f46", |
| | button_secondary_text_color="#f5f5f5", |
| | ) |
| |
|
| | def _normalise_directory_string(value: Optional[Union[str, Path]]) -> str: |
| | if value is None: |
| | return "" |
| | path = Path(value).expanduser() |
| | try: |
| | return str(path.resolve()) |
| | except Exception: |
| | return str(path) |
| |
|
| | with gr.Blocks( |
| | title="Fault Classification - PMU Data", theme=theme, css=APP_CSS |
| | ) as demo: |
| | gr.Markdown("# Fault Classification for PMU & PV Data") |
| | gr.Markdown( |
| | "🖥️ TensorFlow is locked to CPU execution so the Space can run without CUDA drivers." |
| | ) |
| | if MODEL is None or SCALER is None: |
| | gr.Markdown( |
| | "⚠️ **Artifacts Missing** — Upload `pmu_cnn_lstm_model.keras`, " |
| | "`pmu_feature_scaler.pkl`, and `pmu_metadata.json` to enable inference, " |
| | "or configure the Hugging Face Hub environment variables so they can be downloaded." |
| | ) |
| | else: |
| | class_count = len(LABEL_CLASSES) if LABEL_CLASSES else "unknown" |
| | gr.Markdown( |
| | f"Loaded a **{MODEL_TYPE.upper()}** model ({MODEL_FORMAT.upper()}) with " |
| | f"{len(FEATURE_COLUMNS)} features, sequence length **{SEQUENCE_LENGTH}**, and " |
| | f"{class_count} target classes. Use the tabs below to run inference or fine-tune " |
| | "the model with your own CSV files." |
| | ) |
| |
|
| | with gr.Accordion("Feature Reference", open=False): |
| | gr.Markdown( |
| | f"Each time window expects **{len(FEATURE_COLUMNS)} features** ordered as follows:\n" |
| | + "\n".join(f"- {name}" for name in FEATURE_COLUMNS) |
| | ) |
| | gr.Markdown( |
| | f"Default training parameters: **sequence length = {SEQUENCE_LENGTH}**, " |
| | f"**stride = {DEFAULT_WINDOW_STRIDE}**. Adjust them in the tabs as needed." |
| | ) |
| |
|
| | with gr.Tabs(): |
| | with gr.Tab("Overview"): |
| | gr.Markdown(PROJECT_OVERVIEW_MD) |
| | with gr.Tab("Inference"): |
| | gr.Markdown("## Run Inference") |
| | with gr.Row(): |
| | file_in = gr.File(label="Upload PMU CSV", file_types=[".csv"]) |
| | text_in = gr.Textbox( |
| | lines=4, |
| | label="Or paste a single window (comma separated)", |
| | placeholder="49.97772,1.215825E-38,...", |
| | ) |
| |
|
| | with gr.Row(): |
| | sequence_length_input = gr.Slider( |
| | minimum=1, |
| | maximum=max(1, SEQUENCE_LENGTH * 2), |
| | step=1, |
| | value=SEQUENCE_LENGTH, |
| | label="Sequence length (timesteps)", |
| | ) |
| | stride_input = gr.Slider( |
| | minimum=1, |
| | maximum=max(1, SEQUENCE_LENGTH), |
| | step=1, |
| | value=max(1, DEFAULT_WINDOW_STRIDE), |
| | label="CSV window stride", |
| | ) |
| |
|
| | predict_btn = gr.Button("🚀 Run Inference", variant="primary") |
| | status_out = gr.Textbox(label="Status", interactive=False) |
| | table_out = gr.Dataframe( |
| | headers=["window", "predicted_label", "confidence", "top3"], |
| | label="Predictions", |
| | interactive=False, |
| | ) |
| | probs_out = gr.JSON(label="Per-window probabilities") |
| |
|
| | def _run_prediction(file_obj, text, sequence_length, stride): |
| | sequence_length = int(sequence_length) |
| | stride = int(stride) |
| | try: |
| | if file_obj is not None: |
| | return predict_from_csv(file_obj, sequence_length, stride) |
| | if text and text.strip(): |
| | return predict_from_text(text, sequence_length) |
| | return ( |
| | "Please upload a CSV file or provide feature values.", |
| | pd.DataFrame(), |
| | [], |
| | ) |
| | except Exception as exc: |
| | return f"Prediction failed: {exc}", pd.DataFrame(), [] |
| |
|
| | predict_btn.click( |
| | _run_prediction, |
| | inputs=[file_in, text_in, sequence_length_input, stride_input], |
| | outputs=[status_out, table_out, probs_out], |
| | concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
| | ) |
| |
|
| | with gr.Tab("Training"): |
| | gr.Markdown("## Train or Fine-tune the Model") |
| | gr.Markdown( |
| | "Training data is automatically downloaded from the database. " |
| | "Refresh the cache if new files are added upstream." |
| | ) |
| |
|
| | training_files_state = gr.State([]) |
| | with gr.Row(): |
| | with gr.Column(scale=3): |
| | training_files_summary = gr.Textbox( |
| | label="Database training CSVs", |
| | value="Training dataset not loaded yet.", |
| | lines=4, |
| | interactive=False, |
| | elem_id="training-files-summary", |
| | ) |
| | with gr.Column(scale=2, min_width=240): |
| | dataset_info = gr.Markdown( |
| | "No local database CSVs downloaded yet.", |
| | ) |
| | dataset_refresh = gr.Button( |
| | "🔄 Reload dataset from database", |
| | variant="secondary", |
| | ) |
| | clear_cache_button = gr.Button( |
| | "🧹 Clear downloaded cache", |
| | variant="secondary", |
| | ) |
| |
|
| | with gr.Accordion("📂 DataBaseBrowser", open=False): |
| | gr.Markdown( |
| | "Browse the upstream database by date and download only the CSVs you need." |
| | ) |
| | with gr.Row(elem_id="date-browser-row"): |
| | with gr.Column(scale=1, elem_classes=["date-browser-column"]): |
| | year_selector = gr.Dropdown(label="Year", choices=[]) |
| | year_download_button = gr.Button( |
| | "⬇️ Download year CSVs", variant="secondary" |
| | ) |
| | with gr.Column(scale=1, elem_classes=["date-browser-column"]): |
| | month_selector = gr.Dropdown(label="Month", choices=[]) |
| | month_download_button = gr.Button( |
| | "⬇️ Download month CSVs", variant="secondary" |
| | ) |
| | with gr.Column(scale=1, elem_classes=["date-browser-column"]): |
| | day_selector = gr.Dropdown(label="Day", choices=[]) |
| | day_download_button = gr.Button( |
| | "⬇️ Download day CSVs", variant="secondary" |
| | ) |
| | with gr.Column(elem_id="available-files-section"): |
| | available_files = gr.CheckboxGroup( |
| | label="Available CSV files", |
| | choices=[], |
| | value=[], |
| | elem_id="available-files-grid", |
| | ) |
| | download_button = gr.Button( |
| | "⬇️ Download selected CSVs", |
| | variant="secondary", |
| | elem_id="download-selected-button", |
| | ) |
| | repo_status = gr.Markdown( |
| | "Click 'Reload dataset from database' to fetch the directory tree." |
| | ) |
| |
|
| | with gr.Row(): |
| | label_input = gr.Dropdown( |
| | value=LABEL_COLUMN, |
| | choices=[LABEL_COLUMN], |
| | allow_custom_value=True, |
| | label="Label column name", |
| | ) |
| | model_selector = gr.Radio( |
| | choices=["CNN-LSTM", "TCN", "SVM"], |
| | value=( |
| | "TCN" |
| | if MODEL_TYPE == "tcn" |
| | else ("SVM" if MODEL_TYPE == "svm" else "CNN-LSTM") |
| | ), |
| | label="Model architecture", |
| | ) |
| | sequence_length_train = gr.Slider( |
| | minimum=4, |
| | maximum=max(32, SEQUENCE_LENGTH * 2), |
| | step=1, |
| | value=SEQUENCE_LENGTH, |
| | label="Sequence length", |
| | ) |
| | stride_train = gr.Slider( |
| | minimum=1, |
| | maximum=max(32, SEQUENCE_LENGTH * 2), |
| | step=1, |
| | value=max(1, DEFAULT_WINDOW_STRIDE), |
| | label="Stride", |
| | ) |
| |
|
| | model_default = MODEL_FILENAME_BY_TYPE.get( |
| | MODEL_TYPE, Path(LOCAL_MODEL_FILE).name |
| | ) |
| |
|
| | with gr.Row(): |
| | validation_train = gr.Slider( |
| | minimum=0.05, |
| | maximum=0.4, |
| | step=0.05, |
| | value=0.2, |
| | label="Validation split", |
| | ) |
| | batch_train = gr.Slider( |
| | minimum=32, |
| | maximum=512, |
| | step=32, |
| | value=128, |
| | label="Batch size", |
| | ) |
| | epochs_train = gr.Slider( |
| | minimum=5, |
| | maximum=100, |
| | step=5, |
| | value=50, |
| | label="Epochs", |
| | ) |
| |
|
| | directory_choices, directory_default = gather_directory_choices( |
| | str(MODEL_OUTPUT_DIR) |
| | ) |
| | artifact_choices, default_artifact = gather_artifact_choices( |
| | directory_default |
| | ) |
| |
|
| | with gr.Row(): |
| | output_directory = gr.Dropdown( |
| | value=directory_default, |
| | label="Output directory", |
| | choices=directory_choices, |
| | allow_custom_value=True, |
| | ) |
| | model_name = gr.Textbox( |
| | value=model_default, |
| | label="Model output filename", |
| | ) |
| | scaler_name = gr.Textbox( |
| | value=Path(LOCAL_SCALER_FILE).name, |
| | label="Scaler output filename", |
| | ) |
| | metadata_name = gr.Textbox( |
| | value=Path(LOCAL_METADATA_FILE).name, |
| | label="Metadata output filename", |
| | ) |
| |
|
| | with gr.Row(): |
| | artifact_browser = gr.Dropdown( |
| | label="Saved artifacts in directory", |
| | choices=artifact_choices, |
| | value=default_artifact, |
| | ) |
| | artifact_download_button = gr.DownloadButton( |
| | "⬇️ Download selected artifact", |
| | value=default_artifact, |
| | visible=bool(default_artifact), |
| | variant="secondary", |
| | ) |
| |
|
| | def on_output_directory_change(selected_dir, current_selection): |
| | choices, normalised = gather_directory_choices(selected_dir) |
| | artifact_options, selected = gather_artifact_choices( |
| | normalised, current_selection |
| | ) |
| | return ( |
| | gr.update(choices=choices, value=normalised), |
| | gr.update(choices=artifact_options, value=selected), |
| | download_button_state(selected), |
| | ) |
| |
|
| | def on_artifact_change(selected_path): |
| | return download_button_state(selected_path) |
| |
|
| | output_directory.change( |
| | on_output_directory_change, |
| | inputs=[output_directory, artifact_browser], |
| | outputs=[ |
| | output_directory, |
| | artifact_browser, |
| | artifact_download_button, |
| | ], |
| | concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
| | ) |
| |
|
| | artifact_browser.change( |
| | on_artifact_change, |
| | inputs=[artifact_browser], |
| | outputs=[artifact_download_button], |
| | concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
| | ) |
| |
|
| | with gr.Row(elem_id="artifact-download-row"): |
| | model_download_button = gr.DownloadButton( |
| | "⬇️ Download model file", |
| | value=None, |
| | visible=False, |
| | elem_classes=["artifact-download-button"], |
| | ) |
| | scaler_download_button = gr.DownloadButton( |
| | "⬇️ Download scaler file", |
| | value=None, |
| | visible=False, |
| | elem_classes=["artifact-download-button"], |
| | ) |
| | metadata_download_button = gr.DownloadButton( |
| | "⬇️ Download metadata file", |
| | value=None, |
| | visible=False, |
| | elem_classes=["artifact-download-button"], |
| | ) |
| | tensorboard_download_button = gr.DownloadButton( |
| | "⬇️ Download TensorBoard logs", |
| | value=None, |
| | visible=False, |
| | elem_classes=["artifact-download-button"], |
| | ) |
| |
|
| | model_download_button.file_name = Path(LOCAL_MODEL_FILE).name |
| | scaler_download_button.file_name = Path(LOCAL_SCALER_FILE).name |
| | metadata_download_button.file_name = Path(LOCAL_METADATA_FILE).name |
| | tensorboard_download_button.file_name = "tensorboard_logs.zip" |
| |
|
| | tensorboard_toggle = gr.Checkbox( |
| | value=True, |
| | label="Enable TensorBoard logging (creates downloadable archive)", |
| | ) |
| |
|
| | def _suggest_model_filename(choice: str, current_value: str): |
| | choice_key = (choice or "cnn_lstm").lower().replace("-", "_") |
| | suggested = MODEL_FILENAME_BY_TYPE.get( |
| | choice_key, Path(LOCAL_MODEL_FILE).name |
| | ) |
| | known_defaults = set(MODEL_FILENAME_BY_TYPE.values()) |
| | current_name = Path(current_value).name if current_value else "" |
| | if current_name and current_name not in known_defaults: |
| | return gr.update() |
| | return gr.update(value=suggested) |
| |
|
| | model_selector.change( |
| | _suggest_model_filename, |
| | inputs=[model_selector, model_name], |
| | outputs=model_name, |
| | ) |
| |
|
| | with gr.Row(): |
| | train_button = gr.Button("🛠️ Start Training", variant="primary") |
| | progress_button = gr.Button( |
| | "📊 Check Progress", variant="secondary" |
| | ) |
| |
|
| | |
| | training_status = gr.Textbox(label="Training Status", interactive=False) |
| | report_output = gr.Dataframe( |
| | label="Classification report", interactive=False |
| | ) |
| | history_output = gr.JSON(label="Training history") |
| | confusion_output = gr.Dataframe( |
| | label="Confusion matrix", interactive=False |
| | ) |
| |
|
| | |
| | with gr.Accordion("📋 Progress Messages", open=True): |
| | progress_messages = gr.Textbox( |
| | label="Training Messages", |
| | lines=8, |
| | max_lines=20, |
| | interactive=False, |
| | autoscroll=True, |
| | placeholder="Click 'Check Progress' to see training updates...", |
| | ) |
| | with gr.Row(): |
| | gr.Button("🗑️ Clear Messages", variant="secondary").click( |
| | lambda: "", outputs=[progress_messages] |
| | ) |
| |
|
def _run_training(
    file_paths,
    label_column,
    model_choice,
    sequence_length,
    stride,
    validation_split,
    batch_size,
    epochs,
    output_dir,
    model_filename,
    scaler_filename,
    metadata_filename,
    enable_tensorboard,
):
    """Train a model on the cached database CSVs and refresh the UI.

    Concatenates every CSV in ``file_paths``, trains the selected
    architecture via ``train_from_dataframe``, reloads the in-memory
    inference artifacts, and returns the 11-tuple of component updates
    wired up in ``train_button.click`` below.  On failure it returns the
    same arity with an error status so the UI stays consistent.
    """
    base_dir = normalise_output_directory(output_dir)
    # Resolved inside the try block once the output paths are known; kept
    # out here so the except handler can record the failure for the
    # progress poller (_check_progress reads this file).
    status_file = None
    try:
        base_dir.mkdir(parents=True, exist_ok=True)

        model_path = resolve_output_path(
            base_dir,
            model_filename,
            Path(LOCAL_MODEL_FILE).name,
        )
        scaler_path = resolve_output_path(
            base_dir,
            scaler_filename,
            Path(LOCAL_SCALER_FILE).name,
        )
        metadata_path = resolve_output_path(
            base_dir,
            metadata_filename,
            Path(LOCAL_METADATA_FILE).name,
        )

        model_path.parent.mkdir(parents=True, exist_ok=True)
        scaler_path.parent.mkdir(parents=True, exist_ok=True)
        metadata_path.parent.mkdir(parents=True, exist_ok=True)

        status_file = model_path.parent / "training_status.txt"
        status_file.write_text("Starting training setup...")

        if not file_paths:
            raise ValueError(
                "No training CSVs were found in the database cache. "
                "Use 'Reload dataset from database' and try again."
            )

        status_file.write_text("Loading and validating CSV files...")

        available_paths = [
            path for path in file_paths if Path(path).exists()
        ]
        missing_paths = [
            Path(path).name
            for path in file_paths
            if not Path(path).exists()
        ]
        if not available_paths:
            raise ValueError(
                "Database training dataset is unavailable. Reload the dataset and retry."
            )

        dfs = [load_measurement_csv(path) for path in available_paths]
        combined = pd.concat(dfs, ignore_index=True)

        label_column = (label_column or LABEL_COLUMN).strip()
        if not label_column:
            raise ValueError("Label column name cannot be empty.")

        # Normalise e.g. "CNN-LSTM" -> "cnn_lstm" BEFORE the small-dataset
        # check below.  Previously the raw dropdown value was compared
        # against {"cnn_lstm", "tcn"}, so the automatic SVM fallback never
        # matched a "CNN-LSTM" selection.
        model_choice = (
            (model_choice or "CNN-LSTM").lower().replace("-", "_")
        )
        if model_choice not in {"cnn_lstm", "tcn", "svm"}:
            raise ValueError(
                "Select CNN-LSTM, TCN, or SVM for the model architecture."
            )

        # Small datasets cannot support deep sequence models; fall back to
        # SVM and refuse to train on fewer than 10 rows outright.
        total_samples = len(combined)
        if total_samples < 100:
            print(
                f"Warning: Only {total_samples} samples. Recommend at least 1000 for good results."
            )
            print(
                "Automatically switching to SVM for small dataset compatibility."
            )
            if model_choice in ["cnn_lstm", "tcn"]:
                model_choice = "svm"
                print(
                    "Model type changed to SVM for better small dataset performance."
                )
            if total_samples < 10:
                raise ValueError(
                    f"Insufficient data: {total_samples} samples. Need at least 10 samples for training."
                )

        status_file.write_text(
            f"Starting {model_choice.upper()} training with {len(combined)} samples..."
        )

        result = train_from_dataframe(
            combined,
            label_column=label_column,
            feature_columns=None,  # let the trainer auto-detect features
            sequence_length=int(sequence_length),
            stride=int(stride),
            validation_split=float(validation_split),
            batch_size=int(batch_size),
            epochs=int(epochs),
            model_type=model_choice,
            model_path=model_path,
            scaler_path=scaler_path,
            metadata_path=metadata_path,
            enable_tensorboard=bool(enable_tensorboard),
        )

        # Reload the freshly trained artifacts into the prediction path.
        refresh_artifacts(
            Path(result["model_path"]),
            Path(result["scaler_path"]),
            Path(result["metadata_path"]),
        )

        report_df = classification_report_to_dataframe(
            result["classification_report"]
        )
        confusion_df = confusion_matrix_to_dataframe(
            result["confusion_matrix"], result["class_names"]
        )
        tensorboard_dir = result.get("tensorboard_log_dir")
        tensorboard_zip = result.get("tensorboard_zip_path")

        architecture = result["model_type"].replace("_", "-").upper()
        status = (
            f"Training complete using a {architecture} architecture. "
            f"{result['num_sequences']} windows derived from "
            f"{result['num_samples']} rows across {len(available_paths)} file(s)."
            f" Artifacts saved to:"
            f"\n• Model: {result['model_path']}\n"
            f"• Scaler: {result['scaler_path']}\n"
            f"• Metadata: {result['metadata_path']}"
        )

        status += f"\nLabel column used: {result.get('label_column', label_column)}"

        if tensorboard_dir:
            status += (
                f"\nTensorBoard logs directory: {tensorboard_dir}"
                f'\nRun `tensorboard --logdir "{tensorboard_dir}"` to inspect the training curves.'
                "\nDownload the archive below to explore the run offline."
            )

        if missing_paths:
            skipped = ", ".join(missing_paths)
            status = f"⚠️ Skipped missing files: {skipped}\n" + status

        artifact_choices, selected_artifact = gather_artifact_choices(
            str(base_dir), result["model_path"]
        )

        return (
            status,
            report_df,
            result["history"],
            confusion_df,
            download_button_state(result["model_path"]),
            download_button_state(result["scaler_path"]),
            download_button_state(result["metadata_path"]),
            download_button_state(tensorboard_zip),
            gr.update(value=result.get("label_column", label_column)),
            gr.update(
                choices=artifact_choices, value=selected_artifact
            ),
            download_button_state(selected_artifact),
        )
    except Exception as exc:
        # Keep the full stack in the server logs; the UI only sees the
        # one-line summary below.
        import traceback

        traceback.print_exc()
        # Best-effort: surface the failure to the progress poller so it
        # does not keep reporting a stale "Starting..." status.
        if status_file is not None:
            try:
                status_file.write_text(f"Training failed: {exc}")
            except OSError:
                pass
        artifact_choices, selected_artifact = gather_artifact_choices(
            str(base_dir)
        )
        return (
            f"Training failed: {exc}",
            pd.DataFrame(),
            {},
            pd.DataFrame(),
            download_button_state(None),
            download_button_state(None),
            download_button_state(None),
            download_button_state(None),
            gr.update(),
            gr.update(
                choices=artifact_choices, value=selected_artifact
            ),
            download_button_state(selected_artifact),
        )
| |
|
def _check_progress(output_dir, model_filename, current_messages):
    """Poll the training status file and append a timestamped entry.

    Returns the updated message log, capped at the 50 most recent lines
    so the textbox never grows without bound.
    """
    from datetime import datetime

    status_path = (
        resolve_output_path(
            output_dir, model_filename, Path(LOCAL_MODEL_FILE).name
        ).parent
        / "training_status.txt"
    )
    stamped = (
        f"[{datetime.now().strftime('%H:%M:%S')}] "
        f"{read_training_status(str(status_path))}"
    )

    if not current_messages:
        return stamped

    history = current_messages.split("\n")
    history.append(stamped)
    return "\n".join(history[-50:])
| |
|
# Launch training with all hyper-parameters and artifact destinations;
# the 11 outputs mirror the tuple returned by _run_training.
train_button.click(
    _run_training,
    inputs=[
        training_files_state,
        label_input,
        model_selector,
        sequence_length_train,
        stride_train,
        validation_train,
        batch_train,
        epochs_train,
        output_directory,
        model_name,
        scaler_name,
        metadata_name,
        tensorboard_toggle,
    ],
    outputs=[
        training_status,
        report_output,
        history_output,
        confusion_output,
        model_download_button,
        scaler_download_button,
        metadata_download_button,
        tensorboard_download_button,
        label_input,
        artifact_browser,
        artifact_download_button,
    ],
    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
)

# Manual progress poll: reads the status file and appends to the log box.
progress_button.click(
    _check_progress,
    inputs=[output_directory, model_name, progress_messages],
    outputs=[progress_messages],
)
| |
|
# Cascading date pickers: changing the year refreshes month/day choices
# and the file list; narrower selections update only their dependents.
year_selector.change(
    on_year_change,
    inputs=[year_selector],
    outputs=[
        month_selector,
        day_selector,
        available_files,
        repo_status,
    ],
    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
)

month_selector.change(
    on_month_change,
    inputs=[year_selector, month_selector],
    outputs=[day_selector, available_files, repo_status],
    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
)

day_selector.change(
    on_day_change,
    inputs=[year_selector, month_selector, day_selector],
    outputs=[available_files, repo_status],
    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
)
| |
|
# Dataset download actions at increasing granularity (selected files,
# whole year, whole month, single day).  All share the same output set so
# the training-file state and status widgets stay in sync.
download_button.click(
    download_selected_files,
    inputs=[
        year_selector,
        month_selector,
        day_selector,
        available_files,
        label_input,
    ],
    outputs=[
        training_files_state,
        training_files_summary,
        label_input,
        dataset_info,
        available_files,
        repo_status,
    ],
    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
)

year_download_button.click(
    download_year_bundle,
    inputs=[year_selector, label_input],
    outputs=[
        training_files_state,
        training_files_summary,
        label_input,
        dataset_info,
        available_files,
        repo_status,
    ],
    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
)

month_download_button.click(
    download_month_bundle,
    inputs=[year_selector, month_selector, label_input],
    outputs=[
        training_files_state,
        training_files_summary,
        label_input,
        dataset_info,
        available_files,
        repo_status,
    ],
    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
)

day_download_button.click(
    download_day_bundle,
    inputs=[year_selector, month_selector, day_selector, label_input],
    outputs=[
        training_files_state,
        training_files_summary,
        label_input,
        dataset_info,
        available_files,
        repo_status,
    ],
    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
)
| |
|
def _reload_dataset(current_label):
    """Force-refresh the local CSV cache and the remote date browser."""
    local_updates = load_repository_training_files(
        current_label, force_refresh=True
    )
    remote_updates = refresh_remote_browser(force_refresh=True)
    return tuple(local_updates) + tuple(remote_updates)
| |
|
# Force-reload the dataset from the database and refresh the browser.
dataset_refresh.click(
    _reload_dataset,
    inputs=[label_input],
    outputs=[
        training_files_state,
        training_files_summary,
        label_input,
        dataset_info,
        year_selector,
        month_selector,
        day_selector,
        available_files,
        repo_status,
    ],
    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
)

# Drop the downloaded CSV cache; updates the same widgets as a reload.
clear_cache_button.click(
    clear_downloaded_cache,
    inputs=[label_input],
    outputs=[
        training_files_state,
        training_files_summary,
        label_input,
        dataset_info,
        year_selector,
        month_selector,
        day_selector,
        available_files,
        repo_status,
    ],
    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
)
| |
|
def _initialise_dataset():
    """Populate dataset widgets on first page load (cache-friendly)."""
    local_updates = load_repository_training_files(
        LABEL_COLUMN, force_refresh=False
    )
    remote_updates = refresh_remote_browser(force_refresh=False)
    return tuple(local_updates) + tuple(remote_updates)
| |
|
# Initialise all dataset widgets when the page loads; queue=False runs
# the handler immediately instead of going through the shared queue.
demo.load(
    _initialise_dataset,
    inputs=None,
    outputs=[
        training_files_state,
        training_files_summary,
        label_input,
        dataset_info,
        year_selector,
        month_selector,
        day_selector,
        available_files,
        repo_status,
    ],
    queue=False,
)

return demo
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def resolve_server_port() -> int:
    """Pick the HTTP port: $PORT first, then $GRADIO_SERVER_PORT, else 7860.

    Unset or empty variables are skipped; non-numeric values are reported
    and skipped rather than crashing the launcher.
    """
    for env_var in ("PORT", "GRADIO_SERVER_PORT"):
        raw = os.environ.get(env_var)
        if not raw:
            continue
        try:
            return int(raw)
        except ValueError:
            print(f"Ignoring invalid port value from {env_var}: {raw}")
    return 7860
| |
|
| |
|
def main():
    """Construct the Gradio app, configure queuing, and start the server."""
    print("Building Gradio interface...")
    try:
        demo = build_interface()
    except Exception as err:
        print(f"Failed to build interface: {err}")
        import traceback

        traceback.print_exc()
        return
    print("Interface built successfully")

    print("Setting up queue...")
    try:
        demo.queue(max_size=QUEUE_MAX_SIZE)
    except Exception as err:
        print(f"Failed to configure queue: {err}")
    else:
        print("Queue configured")

    # Launch on the configured port; if that port is taken, retry letting
    # Gradio pick a free one.
    launch_kwargs = {"server_name": "0.0.0.0", "show_error": True}
    try:
        port = resolve_server_port()
        print(f"Launching Gradio app on port {port}")
        demo.launch(server_port=port, **launch_kwargs)
    except OSError as err:
        print("Failed to launch on requested port:", err)
        try:
            demo.launch(**launch_kwargs)
        except Exception as nested:
            print(f"Failed to launch completely: {nested}")
    except Exception as err:
        print(f"Unexpected launch error: {err}")
        import traceback

        traceback.print_exc()
| |
|
| |
|
if __name__ == "__main__":
    # ``os.sys`` only worked because the os module happens to re-export
    # sys; import it explicitly instead of relying on that detail.
    import sys

    # Startup banner: environment summary to aid debugging on Spaces.
    print("=" * 50)
    print("PMU Fault Classification App Starting")
    print(f"Python version: {sys.version}")
    print(f"Working directory: {os.getcwd()}")
    print(f"HUB_REPO: {HUB_REPO}")
    print(f"Model available: {MODEL is not None}")
    print(f"Scaler available: {SCALER is not None}")
    print("=" * 50)
    main()
| |
|