GuanHuaYu (student) committed
Commit 31d480d · 1 Parent(s): a2b637a
.history/app_20251009231256.py ADDED
@@ -0,0 +1,2282 @@
+ """Gradio front-end for Fault_Classification_PMU_Data models.
+
+ The application loads a CNN-LSTM model (and accompanying scaler/metadata)
+ produced by ``fault_classification_pmu.py`` and exposes a streamlined
+ prediction interface optimised for Hugging Face Spaces deployment. It supports
+ raw PMU time-series CSV uploads as well as manual comma-separated feature
+ vectors.
+ """
+ from __future__ import annotations
+
+ import json
+ import os
+ import shutil
+
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1")
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
+ os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
+
+ import re
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+ import gradio as gr
+ import joblib
+ import numpy as np
+ import pandas as pd
+ import requests
+ from huggingface_hub import hf_hub_download
+ from tensorflow.keras.models import load_model
+
+ from fault_classification_pmu import (
+     DEFAULT_FEATURE_COLUMNS as TRAINING_DEFAULT_FEATURE_COLUMNS,
+     LABEL_GUESS_CANDIDATES as TRAINING_LABEL_GUESSES,
+     train_from_dataframe,
+ )
+
+ # --------------------------------------------------------------------------------------
+ # Configuration
+ # --------------------------------------------------------------------------------------
+ DEFAULT_FEATURE_COLUMNS: List[str] = list(TRAINING_DEFAULT_FEATURE_COLUMNS)
+ DEFAULT_SEQUENCE_LENGTH = 32
+ DEFAULT_STRIDE = 4
+
+ LOCAL_MODEL_FILE = os.environ.get("PMU_MODEL_FILE", "pmu_cnn_lstm_model.keras")
+ LOCAL_SCALER_FILE = os.environ.get("PMU_SCALER_FILE", "pmu_feature_scaler.pkl")
+ LOCAL_METADATA_FILE = os.environ.get("PMU_METADATA_FILE", "pmu_metadata.json")
+
+ MODEL_OUTPUT_DIR = Path(os.environ.get("PMU_MODEL_DIR", "model")).resolve()
+ MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+ HUB_REPO = os.environ.get("PMU_HUB_REPO", "")
+ HUB_MODEL_FILENAME = os.environ.get("PMU_HUB_MODEL_FILENAME", LOCAL_MODEL_FILE)
+ HUB_SCALER_FILENAME = os.environ.get("PMU_HUB_SCALER_FILENAME", LOCAL_SCALER_FILE)
+ HUB_METADATA_FILENAME = os.environ.get("PMU_HUB_METADATA_FILENAME", LOCAL_METADATA_FILE)
+
+ ENV_MODEL_PATH = "PMU_MODEL_PATH"
+ ENV_SCALER_PATH = "PMU_SCALER_PATH"
+ ENV_METADATA_PATH = "PMU_METADATA_PATH"
+
+ # --------------------------------------------------------------------------------------
+ # Utility functions for loading artifacts
+ # --------------------------------------------------------------------------------------
+
+ def download_from_hub(filename: str) -> Optional[Path]:
+     if not HUB_REPO or not filename:
+         return None
+     try:
+         print(f"Downloading {filename} from {HUB_REPO} ...")
+         # NOTE: no explicit timeout is passed here; hf_hub_download relies on
+         # its default network timeouts, so slow downloads may take a while.
+         path = hf_hub_download(repo_id=HUB_REPO, filename=filename)
+         print("Downloaded", path)
+         return Path(path)
+     except Exception as exc:  # pragma: no cover - logging convenience
+         print("Failed to download", filename, "from", HUB_REPO, ":", exc)
+         print("Continuing without pre-trained model...")
+         return None
+
+
+ def resolve_artifact(local_name: str, env_var: str, hub_filename: str) -> Optional[Path]:
+     print(f"Resolving artifact: {local_name}, env: {env_var}, hub: {hub_filename}")
+     candidates = [Path(local_name)] if local_name else []
+     if local_name:
+         candidates.append(MODEL_OUTPUT_DIR / Path(local_name).name)
+     env_value = os.environ.get(env_var)
+     if env_value:
+         candidates.append(Path(env_value))
+
+     for candidate in candidates:
+         if candidate and candidate.exists():
+             print(f"Found local artifact: {candidate}")
+             return candidate
+
+     print("No local artifacts found, checking hub...")
+     # Only try to download if we have a hub repo configured
+     if HUB_REPO:
+         return download_from_hub(hub_filename)
+     else:
+         print("No HUB_REPO configured, skipping download")
+         return None
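+ # Illustration of the fallback chain above, using the default file names:
+ # resolve_artifact("pmu_metadata.json", "PMU_METADATA_PATH", "pmu_metadata.json")
+ # checks ./pmu_metadata.json, then model/pmu_metadata.json, then the path in
+ # $PMU_METADATA_PATH, and only then downloads from $PMU_HUB_REPO when it is set.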
+
+
+ def load_metadata(path: Optional[Path]) -> Dict:
+     if path and path.exists():
+         try:
+             return json.loads(path.read_text())
+         except Exception as exc:  # pragma: no cover - metadata parsing errors
+             print("Failed to read metadata", path, exc)
+     return {}
+
+
+ def try_load_scaler(path: Optional[Path]):
+     if not path:
+         return None
+     try:
+         scaler = joblib.load(path)
+         print("Loaded scaler from", path)
+         return scaler
+     except Exception as exc:
+         print("Failed to load scaler", path, exc)
+         return None
+
+
+ # Initialize paths with error handling
+ print("Starting application initialization...")
+ try:
+     MODEL_PATH = resolve_artifact(LOCAL_MODEL_FILE, ENV_MODEL_PATH, HUB_MODEL_FILENAME)
+     print(f"Model path resolved: {MODEL_PATH}")
+ except Exception as e:
+     print(f"Model path resolution failed: {e}")
+     MODEL_PATH = None
+
+ try:
+     SCALER_PATH = resolve_artifact(LOCAL_SCALER_FILE, ENV_SCALER_PATH, HUB_SCALER_FILENAME)
+     print(f"Scaler path resolved: {SCALER_PATH}")
+ except Exception as e:
+     print(f"Scaler path resolution failed: {e}")
+     SCALER_PATH = None
+
+ try:
+     METADATA_PATH = resolve_artifact(LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME)
+     print(f"Metadata path resolved: {METADATA_PATH}")
+ except Exception as e:
+     print(f"Metadata path resolution failed: {e}")
+     METADATA_PATH = None
+
+ try:
+     METADATA = load_metadata(METADATA_PATH)
+     print(f"Metadata loaded: {len(METADATA)} entries")
+ except Exception as e:
+     print(f"Metadata loading failed: {e}")
+     METADATA = {}
+
+ # Queuing configuration
+ QUEUE_MAX_SIZE = 32
+ # Apply a small per-event concurrency limit to avoid relying on the deprecated
+ # ``concurrency_count`` parameter when enabling Gradio's request queue.
+ EVENT_CONCURRENCY_LIMIT = 2
+
+ def try_load_model(path: Optional[Path], model_type: str, model_format: str):
+     if not path:
+         return None
+     try:
+         if model_type == "svm" or model_format == "joblib":
+             model = joblib.load(path)
+         else:
+             model = load_model(path)
+         print("Loaded model from", path)
+         return model
+     except Exception as exc:  # pragma: no cover - runtime diagnostics
+         print("Failed to load model", path, exc)
+         return None
+
+
+ FEATURE_COLUMNS: List[str] = list(DEFAULT_FEATURE_COLUMNS)
+ LABEL_CLASSES: List[str] = []
+ LABEL_COLUMN: str = "Fault"
+ SEQUENCE_LENGTH: int = DEFAULT_SEQUENCE_LENGTH
+ DEFAULT_WINDOW_STRIDE: int = DEFAULT_STRIDE
+ MODEL_TYPE: str = "cnn_lstm"
+ MODEL_FORMAT: str = "keras"
+
+ def _model_output_path(filename: str) -> str:
+     return str(MODEL_OUTPUT_DIR / Path(filename).name)
+
+
+ MODEL_FILENAME_BY_TYPE: Dict[str, str] = {
+     "cnn_lstm": Path(LOCAL_MODEL_FILE).name,
+     "tcn": "pmu_tcn_model.keras",
+     "svm": "pmu_svm_model.joblib",
+ }
+
+ REQUIRED_PMU_COLUMNS: Tuple[str, ...] = tuple(DEFAULT_FEATURE_COLUMNS)
+ TRAINING_UPLOAD_DIR = Path(os.environ.get("PMU_TRAINING_UPLOAD_DIR", "training_uploads"))
+ TRAINING_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
+
+ TRAINING_DATA_REPO = os.environ.get("PMU_TRAINING_DATA_REPO", "VincentCroft/ThesisModelData")
+ TRAINING_DATA_BRANCH = os.environ.get("PMU_TRAINING_DATA_BRANCH", "main")
+ TRAINING_DATA_DIR = Path(os.environ.get("PMU_TRAINING_DATA_DIR", "training_dataset"))
+ TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
+
+ GITHUB_CONTENT_CACHE: Dict[str, List[Dict[str, Any]]] = {}
+
+
+ APP_CSS = """
+ #available-files-section {
+     position: relative;
+     display: flex;
+     flex-direction: column;
+     gap: 0.75rem;
+     border-radius: 0.75rem;
+ }
+
+ #available-files-grid {
+     position: relative;
+     overflow: visible;
+ }
+
+ #available-files-grid .form {
+     position: relative;
+     min-height: 16rem;
+ }
+
+ #available-files-section:has(.gradio-loading) {
+     isolation: isolate;
+ }
+
+ #available-files-grid .wrap {
+     display: grid;
+     grid-template-columns: repeat(4, minmax(0, 1fr));
+     gap: 0.5rem;
+     max-height: 24rem;
+     min-height: 16rem;
+     overflow-y: auto;
+     padding-right: 0.25rem;
+ }
+
+ #available-files-grid .wrap > div {
+     min-width: 0;
+ }
+
+ #available-files-grid .wrap label {
+     margin: 0;
+     display: flex;
+     align-items: center;
+     padding: 0.45rem 0.65rem;
+     border-radius: 0.65rem;
+     background-color: rgba(255, 255, 255, 0.05);
+     border: 1px solid rgba(255, 255, 255, 0.08);
+     transition: background-color 0.2s ease, border-color 0.2s ease;
+     min-height: 2.5rem;
+ }
+
+ #available-files-grid .wrap label:hover {
+     background-color: rgba(90, 200, 250, 0.16);
+     border-color: rgba(90, 200, 250, 0.4);
+ }
+
+ #available-files-grid .wrap label span {
+     overflow: hidden;
+     text-overflow: ellipsis;
+     white-space: nowrap;
+ }
+
+ #available-files-grid .gradio-loading {
+     position: absolute;
+     inset: 0;
+     width: auto;
+     height: auto;
+     min-height: 100%;
+     display: flex;
+     align-items: center;
+     justify-content: center;
+     background: rgba(10, 14, 23, 0.72);
+     border-radius: 0.75rem;
+     z-index: 10;
+     padding: 1.5rem;
+     pointer-events: auto;
+ }
+
+ #available-files-grid .gradio-loading > * {
+     width: 100%;
+ }
+
+ #available-files-grid .gradio-loading progress,
+ #available-files-grid .gradio-loading .progress-bar,
+ #available-files-grid .gradio-loading .loading-progress,
+ #available-files-grid .gradio-loading [role="progressbar"],
+ #available-files-grid .gradio-loading .wrap,
+ #available-files-grid .gradio-loading .inner {
+     width: 100% !important;
+     max-width: none !important;
+ }
+
+ #available-files-grid .gradio-loading .status,
+ #available-files-grid .gradio-loading .message,
+ #available-files-grid .gradio-loading .label {
+     text-align: center;
+ }
+
+ #date-browser-row {
+     gap: 0.75rem;
+ }
+
+ #date-browser-row .date-browser-column {
+     flex: 1 1 0%;
+     min-width: 0;
+ }
+
+ #date-browser-row .date-browser-column > .gradio-dropdown,
+ #date-browser-row .date-browser-column > .gradio-button {
+     width: 100%;
+ }
+
+ #date-browser-row .date-browser-column > .gradio-dropdown > div {
+     width: 100%;
+ }
+
+ #date-browser-row .date-browser-column .gradio-button {
+     justify-content: center;
+ }
+
+ #training-files-summary textarea {
+     max-height: 12rem;
+     overflow-y: auto;
+ }
+
+ #download-selected-button {
+     width: 100%;
+     position: relative;
+     z-index: 0;
+ }
+
+ #download-selected-button .gradio-button {
+     width: 100%;
+     justify-content: center;
+ }
+
+ #artifact-download-row {
+     gap: 0.75rem;
+ }
+
+ #artifact-download-row .artifact-download-button {
+     flex: 1 1 0%;
+     min-width: 0;
+ }
+
+ #artifact-download-row .artifact-download-button .gradio-button {
+     width: 100%;
+     justify-content: center;
+ }
+ """
+
+
+ def _github_cache_key(path: str) -> str:
+     return path or "__root__"
+
+
+ def _github_api_url(path: str) -> str:
+     clean_path = path.strip("/")
+     base = f"https://api.github.com/repos/{TRAINING_DATA_REPO}/contents"
+     if clean_path:
+         return f"{base}/{clean_path}?ref={TRAINING_DATA_BRANCH}"
+     return f"{base}?ref={TRAINING_DATA_BRANCH}"
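+ # Example (default repo/branch): _github_api_url("2024/01") builds
+ # https://api.github.com/repos/VincentCroft/ThesisModelData/contents/2024/01?ref=main
+ # where "2024/01" is an illustrative year/month path, not a guaranteed folder.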
+
+
+ def list_remote_directory(path: str = "", *, force_refresh: bool = False) -> List[Dict[str, Any]]:
+     key = _github_cache_key(path)
+     if not force_refresh and key in GITHUB_CONTENT_CACHE:
+         return GITHUB_CONTENT_CACHE[key]
+
+     url = _github_api_url(path)
+     response = requests.get(url, timeout=30)
+     if response.status_code != 200:
+         raise RuntimeError(
+             f"GitHub API request failed for `{path or '.'}` (status {response.status_code})."
+         )
+
+     payload = response.json()
+     if not isinstance(payload, list):
+         raise RuntimeError("Unexpected GitHub API payload. Expected a directory listing.")
+
+     GITHUB_CONTENT_CACHE[key] = payload
+     return payload
+
+
+ def list_remote_years(force_refresh: bool = False) -> List[str]:
+     entries = list_remote_directory("", force_refresh=force_refresh)
+     years = [item["name"] for item in entries if item.get("type") == "dir"]
+     return sorted(years)
+
+
+ def list_remote_months(year: str, *, force_refresh: bool = False) -> List[str]:
+     if not year:
+         return []
+     entries = list_remote_directory(year, force_refresh=force_refresh)
+     months = [item["name"] for item in entries if item.get("type") == "dir"]
+     return sorted(months)
+
+
+ def list_remote_days(year: str, month: str, *, force_refresh: bool = False) -> List[str]:
+     if not year or not month:
+         return []
+     entries = list_remote_directory(f"{year}/{month}", force_refresh=force_refresh)
+     days = [item["name"] for item in entries if item.get("type") == "dir"]
+     return sorted(days)
+
+
+ def list_remote_files(year: str, month: str, day: str, *, force_refresh: bool = False) -> List[str]:
+     if not year or not month or not day:
+         return []
+     entries = list_remote_directory(
+         f"{year}/{month}/{day}", force_refresh=force_refresh
+     )
+     files = [item["name"] for item in entries if item.get("type") == "file"]
+     return sorted(files)
+
+
+ def download_repository_file(year: str, month: str, day: str, filename: str) -> Path:
+     if not filename:
+         raise ValueError("Filename cannot be empty when downloading repository data.")
+
+     relative_parts = [part for part in (year, month, day, filename) if part]
+     if len(relative_parts) < 4:
+         raise ValueError("Provide year, month, day, and filename to download a CSV.")
+
+     relative_path = "/".join(relative_parts)
+     raw_url = (
+         f"https://raw.githubusercontent.com/{TRAINING_DATA_REPO}/"
+         f"{TRAINING_DATA_BRANCH}/{relative_path}"
+     )
+
+     response = requests.get(raw_url, stream=True, timeout=120)
+     if response.status_code != 200:
+         raise RuntimeError(
+             f"Failed to download `{relative_path}` (status {response.status_code})."
+         )
+
+     target_dir = TRAINING_DATA_DIR.joinpath(year, month, day)
+     target_dir.mkdir(parents=True, exist_ok=True)
+     target_path = target_dir / filename
+
+     with open(target_path, "wb") as handle:
+         for chunk in response.iter_content(chunk_size=1 << 20):
+             if chunk:
+                 handle.write(chunk)
+
+     return target_path
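+ # Sketch of a call with a hypothetical date and file name:
+ # download_repository_file("2024", "01", "15", "export.csv") streams
+ # https://raw.githubusercontent.com/VincentCroft/ThesisModelData/main/2024/01/15/export.csv
+ # into training_dataset/2024/01/15/export.csv in 1 MiB chunks.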
+
+
+ def _normalise_header(name: str) -> str:
+     return str(name).strip().lower()
+
+
+ def guess_label_from_columns(columns: Sequence[str], preferred: Optional[str] = None) -> Optional[str]:
+     if not columns:
+         return preferred
+
+     lookup = {_normalise_header(col): str(col) for col in columns}
+
+     if preferred:
+         preferred_stripped = preferred.strip()
+         for col in columns:
+             if str(col).strip() == preferred_stripped:
+                 return str(col)
+         preferred_norm = _normalise_header(preferred)
+         if preferred_norm in lookup:
+             return lookup[preferred_norm]
+
+     for guess in TRAINING_LABEL_GUESSES:
+         guess_norm = _normalise_header(guess)
+         if guess_norm in lookup:
+             return lookup[guess_norm]
+
+     for col in columns:
+         if _normalise_header(col).startswith("fault"):
+             return str(col)
+
+     return str(columns[0])
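+ # Fallback behaviour, with illustrative column names: given
+ # ["Timestamp", "FaultType"] and no preferred or configured match, the
+ # startswith("fault") pass returns "FaultType"; if nothing matches at all,
+ # the first column is returned as a last resort.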
+
+
+ def summarise_training_files(paths: Sequence[str], notes: Sequence[str]) -> str:
+     lines = [Path(path).name for path in paths]
+     lines.extend(notes)
+     return "\n".join(lines) if lines else "No training files available."
+
+
+ def read_training_status(status_file_path: str) -> str:
+     """Read the current training status from file."""
+     try:
+         if Path(status_file_path).exists():
+             with open(status_file_path, 'r') as f:
+                 return f.read().strip()
+     except Exception:
+         pass
+     return "Training status unavailable"
+
+
+ def _persist_uploaded_file(file_obj) -> Optional[Path]:
+     if file_obj is None:
+         return None
+
+     if isinstance(file_obj, (str, Path)):
+         source = Path(file_obj)
+         original_name = source.name
+     else:
+         source = Path(getattr(file_obj, "name", "") or getattr(file_obj, "path", ""))
+         original_name = getattr(file_obj, "orig_name", source.name) or source.name
+     if not source or not source.exists():
+         return None
+
+     original_name = Path(original_name).name or source.name
+
+     base_path = Path(original_name)
+     destination = TRAINING_UPLOAD_DIR / base_path.name
+     counter = 1
+     while destination.exists():
+         suffix = base_path.suffix or ".csv"
+         destination = TRAINING_UPLOAD_DIR / f"{base_path.stem}_{counter}{suffix}"
+         counter += 1
+
+     shutil.copy2(source, destination)
+     return destination
+
+
+ def prepare_training_paths(
+     paths: Sequence[str], current_label: str, cleanup_missing: bool = False
+ ):
+     valid_paths: List[str] = []
+     notes: List[str] = []
+     columns_map: Dict[str, str] = {}
+     for path in paths:
+         try:
+             df = load_measurement_csv(path)
+         except Exception as exc:  # pragma: no cover - user file diagnostics
+             notes.append(f"⚠️ Skipped {Path(path).name}: {exc}")
+             if cleanup_missing:
+                 try:
+                     Path(path).unlink(missing_ok=True)
+                 except Exception:
+                     pass
+             continue
+         valid_paths.append(str(path))
+         for col in df.columns:
+             columns_map[_normalise_header(col)] = str(col)
+
+     summary = summarise_training_files(valid_paths, notes)
+     preferred = current_label or LABEL_COLUMN
+     dropdown_choices = sorted(columns_map.values()) if columns_map else [preferred or LABEL_COLUMN]
+     guessed = guess_label_from_columns(dropdown_choices, preferred)
+     dropdown_value = guessed or preferred or LABEL_COLUMN
+
+     return valid_paths, summary, gr.update(choices=dropdown_choices, value=dropdown_value)
+
+
+ def append_training_files(new_files, existing_paths: Sequence[str], current_label: str):
+     if isinstance(existing_paths, (str, Path)):
+         paths: List[str] = [str(existing_paths)]
+     elif existing_paths is None:
+         paths = []
+     else:
+         paths = list(existing_paths)
+     if new_files:
+         for file in new_files:
+             persisted = _persist_uploaded_file(file)
+             if persisted is None:
+                 continue
+             path_str = str(persisted)
+             if path_str not in paths:
+                 paths.append(path_str)
+
+     return prepare_training_paths(paths, current_label, cleanup_missing=True)
+
+
+ def load_repository_training_files(current_label: str, force_refresh: bool = False):
+     if force_refresh:
+         # Downloads are on-demand and previously downloaded files are kept, so
+         # a forced refresh only re-scans the cache directory below; nothing is
+         # deleted here and the flag merely triggers downstream UI updates.
+         pass
+
+     csv_paths = sorted(
+         str(path)
+         for path in TRAINING_DATA_DIR.rglob("*.csv")
+         if path.is_file()
+     )
+     if not csv_paths:
+         message = (
+             "No local database CSVs are available yet. Use the database browser "
+             "below to download specific days before training."
+         )
+         default_label = current_label or LABEL_COLUMN or "Fault"
+         return (
+             [],
+             message,
+             gr.update(choices=[default_label], value=default_label),
+             message,
+         )
+
+     valid_paths, summary, label_update = prepare_training_paths(
+         csv_paths, current_label, cleanup_missing=False
+     )
+
+     info = (
+         f"Ready with {len(valid_paths)} CSV file(s) cached locally under "
+         f"the database cache `{TRAINING_DATA_DIR}`."
+     )
+
+     return valid_paths, summary, label_update, info
+
+
+ def refresh_remote_browser(force_refresh: bool = False):
+     if force_refresh:
+         GITHUB_CONTENT_CACHE.clear()
+     try:
+         years = list_remote_years(force_refresh=force_refresh)
+         if years:
+             message = "Select a year, month, and day to list available CSV files."
+         else:
+             message = (
+                 "⚠️ No directories were found in the database root. Verify the upstream "
+                 "structure."
+             )
+         return (
+             gr.update(choices=years, value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             message,
+         )
+     except Exception as exc:
+         return (
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             f"⚠️ Failed to query database: {exc}",
+         )
+
+
+ def on_year_change(year: Optional[str]):
+     if not year:
+         return (
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             "Select a year to continue.",
+         )
+     try:
+         months = list_remote_months(year)
+         message = (
+             f"Year `{year}` selected. Choose a month to drill down."
+             if months
+             else f"⚠️ No months available under `{year}`."
+         )
+         return (
+             gr.update(choices=months, value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             message,
+         )
+     except Exception as exc:
+         return (
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             f"⚠️ Failed to list months: {exc}",
+         )
+
+
+ def on_month_change(year: Optional[str], month: Optional[str]):
+     if not year or not month:
+         return (
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             "Select a month to continue.",
+         )
+     try:
+         days = list_remote_days(year, month)
+         message = (
+             f"Month `{year}/{month}` ready. Pick a day to view files."
+             if days
+             else f"⚠️ No day folders found under `{year}/{month}`."
+         )
+         return (
+             gr.update(choices=days, value=None),
+             gr.update(choices=[], value=[]),
+             message,
+         )
+     except Exception as exc:
+         return (
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             f"⚠️ Failed to list days: {exc}",
+         )
+
+
+ def on_day_change(year: Optional[str], month: Optional[str], day: Optional[str]):
+     if not year or not month or not day:
+         return (
+             gr.update(choices=[], value=[]),
+             "Select a day to load file names.",
+         )
+     try:
+         files = list_remote_files(year, month, day)
+         message = (
+             f"{len(files)} file(s) available for `{year}/{month}/{day}`."
+             if files
+             else f"⚠️ No CSV files found under `{year}/{month}/{day}`."
+         )
+         return (
+             gr.update(choices=files, value=[]),
+             message,
+         )
+     except Exception as exc:
+         return (
+             gr.update(choices=[], value=[]),
+             f"⚠️ Failed to list files: {exc}",
+         )
+
+
+ def download_selected_files(
+     year: Optional[str],
+     month: Optional[str],
+     day: Optional[str],
+     filenames: Sequence[str],
+     current_label: str,
+ ):
+     if not filenames:
+         message = "Select at least one CSV before downloading."
+         local = load_repository_training_files(current_label)
+         return (*local, gr.update(), message)
+
+     success: List[str] = []
+     notes: List[str] = []
+     for filename in filenames:
+         try:
+             path = download_repository_file(year or "", month or "", day or "", filename)
+             success.append(str(path))
+         except Exception as exc:
+             notes.append(f"⚠️ {filename}: {exc}")
+
+     local = load_repository_training_files(current_label)
+
+     message_lines = []
+     if success:
+         message_lines.append(
+             f"Downloaded {len(success)} file(s) to the database cache `{TRAINING_DATA_DIR}`."
+         )
+     if notes:
+         message_lines.extend(notes)
+     if not message_lines:
+         message_lines.append("No files were downloaded.")
+
+     return (*local, gr.update(value=[]), "\n".join(message_lines))
+
+
+ def download_day_bundle(
+     year: Optional[str],
+     month: Optional[str],
+     day: Optional[str],
+     current_label: str,
+ ):
+     if not (year and month and day):
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             "Select a year, month, and day before downloading an entire day.",
+         )
+
+     try:
+         files = list_remote_files(year, month, day)
+     except Exception as exc:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"⚠️ Failed to list CSVs for `{year}/{month}/{day}`: {exc}",
+         )
+
+     if not files:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"No CSV files were found for `{year}/{month}/{day}`.",
+         )
+
+     result = list(download_selected_files(year, month, day, files, current_label))
+     result[-1] = (
+         f"Downloaded all {len(files)} CSV file(s) for `{year}/{month}/{day}`.\n"
+         f"{result[-1]}"
+     )
+     return tuple(result)
+
+
+ def download_month_bundle(
+     year: Optional[str], month: Optional[str], current_label: str
+ ):
+     if not (year and month):
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             "Select a year and month before downloading an entire month.",
+         )
+
+     try:
+         days = list_remote_days(year, month)
+     except Exception as exc:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"⚠️ Failed to enumerate days for `{year}/{month}`: {exc}",
+         )
+
+     if not days:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"No day folders were found for `{year}/{month}`.",
+         )
+
+     downloaded = 0
+     notes: List[str] = []
+     for day in days:
+         try:
+             files = list_remote_files(year, month, day)
+         except Exception as exc:
+             notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
+             continue
+         if not files:
+             notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
+             continue
+         for filename in files:
+             try:
+                 download_repository_file(year, month, day, filename)
+                 downloaded += 1
+             except Exception as exc:
+                 notes.append(
+                     f"⚠️ {year}/{month}/{day}/{filename}: {exc}"
+                 )
+
+     local = load_repository_training_files(current_label)
+     message_lines = []
+     if downloaded:
+         message_lines.append(
+             f"Downloaded {downloaded} CSV file(s) for `{year}/{month}` into the "
+             f"database cache `{TRAINING_DATA_DIR}`."
+         )
+     message_lines.extend(notes)
+     if not message_lines:
+         message_lines.append("No files were downloaded.")
+
+     return (*local, gr.update(value=[]), "\n".join(message_lines))
+
+
+ def download_year_bundle(year: Optional[str], current_label: str):
+     if not year:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             "Select a year before downloading an entire year of CSVs.",
+         )
+
+     try:
+         months = list_remote_months(year)
+     except Exception as exc:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"⚠️ Failed to enumerate months for `{year}`: {exc}",
+         )
+
+     if not months:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"No month folders were found for `{year}`.",
+         )
+
+     downloaded = 0
+     notes: List[str] = []
+     for month in months:
+         try:
+             days = list_remote_days(year, month)
+         except Exception as exc:
+             notes.append(f"⚠️ Failed to list `{year}/{month}`: {exc}")
+             continue
+         if not days:
+             notes.append(f"⚠️ No day folders in `{year}/{month}`.")
+             continue
+         for day in days:
+             try:
+                 files = list_remote_files(year, month, day)
+             except Exception as exc:
+                 notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
+                 continue
+             if not files:
+                 notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
+                 continue
+             for filename in files:
+                 try:
+                     download_repository_file(year, month, day, filename)
+                     downloaded += 1
+                 except Exception as exc:
+                     notes.append(
+                         f"⚠️ {year}/{month}/{day}/{filename}: {exc}"
+                     )
+
+     local = load_repository_training_files(current_label)
+     message_lines = []
+     if downloaded:
+         message_lines.append(
+             f"Downloaded {downloaded} CSV file(s) for `{year}` into the "
+             f"database cache `{TRAINING_DATA_DIR}`."
+         )
+     message_lines.extend(notes)
+     if not message_lines:
+         message_lines.append("No files were downloaded.")
+
+     return (*local, gr.update(value=[]), "\n".join(message_lines))
+
+
+ def clear_downloaded_cache(current_label: str):
+     status_message = ""
+     try:
+         if TRAINING_DATA_DIR.exists():
+             shutil.rmtree(TRAINING_DATA_DIR)
+         TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
+         status_message = (
+             f"Cleared all downloaded CSVs from database cache `{TRAINING_DATA_DIR}`."
+         )
+     except Exception as exc:
+         status_message = f"⚠️ Failed to clear database cache: {exc}"
+
+     local = load_repository_training_files(current_label, force_refresh=True)
+     remote = list(refresh_remote_browser(force_refresh=False))
+     if status_message:
+         previous = remote[-1]
+         if isinstance(previous, str) and previous:
+             remote[-1] = f"{status_message}\n{previous}"
+         else:
+             remote[-1] = status_message
+
+     return (*local, *remote)
+
+
+ def normalise_output_directory(directory: Optional[str]) -> Path:
+     base = Path(directory or MODEL_OUTPUT_DIR)
+     base = base.expanduser()
+     if not base.is_absolute():
+         base = (Path.cwd() / base).resolve()
+     return base
+
+
+ def resolve_output_path(
+     directory: Optional[Union[Path, str]], filename: Optional[str], fallback: str
+ ) -> Path:
+     if isinstance(directory, Path):
+         base = directory
+     else:
+         base = normalise_output_directory(directory)
+     # Path("") normalises to Path("."), so check the raw filename before
+     # building a candidate; otherwise the fallback name would never apply.
+     if filename:
+         candidate = Path(filename).expanduser()
+         if candidate.is_absolute():
+             return candidate
+         return (base / candidate).resolve()
+     return (base / fallback).resolve()
+
+
+ ARTIFACT_FILE_EXTENSIONS: Tuple[str, ...] = (
+     ".keras",
+     ".h5",
+     ".joblib",
+     ".pkl",
+     ".json",
+     ".onnx",
+     ".zip",
+     ".txt",
+ )
+
+
+ def gather_directory_choices(current: Optional[str]) -> Tuple[List[str], str]:
+     base = normalise_output_directory(current or str(MODEL_OUTPUT_DIR))
+     candidates = {str(base)}
+     try:
+         for candidate in base.parent.iterdir():
+             if candidate.is_dir():
+                 candidates.add(str(candidate.resolve()))
+     except Exception:
+         pass
+     return sorted(candidates), str(base)
+
+
+ def gather_artifact_choices(
+     directory: Optional[str], selection: Optional[str] = None
+ ) -> Tuple[List[Tuple[str, str]], Optional[str]]:
+     base = normalise_output_directory(directory)
+     choices: List[Tuple[str, str]] = []
+     selected_value: Optional[str] = None
+     if base.exists():
+         try:
+             artifacts = sorted(
+                 [
+                     path
+                     for path in base.iterdir()
+                     if path.is_file()
+                     and (
+                         not ARTIFACT_FILE_EXTENSIONS
+                         or path.suffix.lower() in ARTIFACT_FILE_EXTENSIONS
+                     )
+                 ],
+                 key=lambda path: path.name.lower(),
+             )
+             choices = [(artifact.name, str(artifact)) for artifact in artifacts]
+         except Exception:
+             choices = []
+
+     if selection and any(value == selection for _, value in choices):
+         selected_value = selection
+     elif choices:
+         selected_value = choices[0][1]
+
+     return choices, selected_value
+
+
+ def download_button_state(path: Optional[Union[str, Path]]):
+     if not path:
+         return gr.update(value=None, visible=False)
+     candidate = Path(path)
+     if candidate.exists():
+         return gr.update(value=str(candidate), visible=True)
+     return gr.update(value=None, visible=False)
+
+
+ def clear_training_files():
+     default_label = LABEL_COLUMN or "Fault"
+     for cached_file in TRAINING_UPLOAD_DIR.glob("*"):
+         try:
+             if cached_file.is_file():
+                 cached_file.unlink(missing_ok=True)
+         except Exception:
+             pass
+     return (
+         [],
+         "No training files selected.",
+         gr.update(choices=[default_label], value=default_label),
+         gr.update(value=None),
+     )
+
+ PROJECT_OVERVIEW_MD = """
+ ## Project Overview
+
+ This project focuses on classifying faults in electrical transmission lines and
+ grid-connected photovoltaic (PV) systems by combining ensemble learning
+ techniques with deep neural architectures.
+
+ ## Datasets
+
+ ### Transmission Line Fault Dataset
+ - 134,406 samples collected from Phasor Measurement Units (PMUs)
+ - 14 monitored channels covering currents, voltages, magnitudes, frequency, and phase angles
+ - Labels span symmetrical and asymmetrical faults: NF, L-G, LL, LL-G, LLL, and LLL-G
+ - Time span: 0 to 5.7 seconds with high-frequency sampling
+
+ ### Grid-Connected PV System Fault Dataset
+ - 2,163,480 samples from 16 experimental scenarios
+ - 14 features including PV array measurements (Ipv, Vpv, Vdc), three-phase currents/voltages, aggregate magnitudes (Iabc, Vabc), and frequency indicators (If, Vf)
+ - Captures array, inverter, grid anomaly, feedback sensor, and MPPT controller faults at 9.9989 μs sampling intervals
+
+ ## Data Format Quick Reference
+
+ Each measurement file may be comma or tab separated and typically exposes the
+ following ordered columns:
+
+ 1. `Timestamp`
+ 2. `[325] UPMU_SUB22:FREQ` – system frequency (Hz)
+ 3. `[326] UPMU_SUB22:DFDT` – frequency rate-of-change
+ 4. `[327] UPMU_SUB22:FLAG` – PMU status flag
+ 5. `[328] UPMU_SUB22-L1:MAG` – phase A voltage magnitude
+ 6. `[329] UPMU_SUB22-L1:ANG` – phase A voltage angle
+ 7. `[330] UPMU_SUB22-L2:MAG` – phase B voltage magnitude
+ 8. `[331] UPMU_SUB22-L2:ANG` – phase B voltage angle
+ 9. `[332] UPMU_SUB22-L3:MAG` – phase C voltage magnitude
+ 10. `[333] UPMU_SUB22-L3:ANG` – phase C voltage angle
+ 11. `[334] UPMU_SUB22-C1:MAG` – phase A current magnitude
+ 12. `[335] UPMU_SUB22-C1:ANG` – phase A current angle
+ 13. `[336] UPMU_SUB22-C2:MAG` – phase B current magnitude
+ 14. `[337] UPMU_SUB22-C2:ANG` – phase B current angle
+ 15. `[338] UPMU_SUB22-C3:MAG` – phase C current magnitude
+ 16. `[339] UPMU_SUB22-C3:ANG` – phase C current angle
+
+ The training tab automatically downloads the latest CSV exports from the
+ `VincentCroft/ThesisModelData` repository and concatenates them before building
+ sliding windows.
+
+ ## Models Developed
+
+ 1. **Support Vector Machine (SVM)** – provides the classical machine learning baseline with balanced accuracy across both datasets (85% PMU / 83% PV).
+ 2. **CNN-LSTM** – couples convolutional feature extraction with temporal memory, achieving 92% PMU / 89% PV accuracy.
+ 3. **Temporal Convolutional Network (TCN)** – leverages dilated convolutions for long-range context and delivers the best trade-off between accuracy and training time (94% PMU / 91% PV).
+
+ ## Results Summary
+
+ - **Transmission Line Fault Classification**: SVM 85%, CNN-LSTM 92%, TCN 94%
+ - **PV System Fault Classification**: SVM 83%, CNN-LSTM 89%, TCN 91%
+
+ Use the **Inference** tab to score new PMU/PV windows and the **Training** tab to
+ fine-tune or retrain any of the supported models directly within Hugging Face
+ Spaces. The logs panel will surface TensorBoard archives whenever deep-learning
+ models are trained.
+ """
+
+
+ def load_measurement_csv(path: str) -> pd.DataFrame:
+     """Read a PMU/PV measurement file with flexible separators and column mapping."""
+
+     try:
+         df = pd.read_csv(path, sep=None, engine="python", encoding="utf-8-sig")
+     except Exception:
+         df = None
+         for separator in ("\t", ",", ";"):
+             try:
+                 df = pd.read_csv(path, sep=separator, engine="python", encoding="utf-8-sig")
+                 break
+             except Exception:
+                 df = None
+         if df is None:
+             raise
+
+     # Clean column names
+     df.columns = [str(col).strip() for col in df.columns]
+
+     print(f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns")
+     print(f"Columns: {list(df.columns)}")
+     print(f"Data shape: {df.shape}")
+
+     # Check if we have enough data for training
+     if len(df) < 100:
+         print(f"Warning: Only {len(df)} rows of data. Recommend at least 1000 rows for effective training.")
+
+     # Check for label column
+     has_label = any(col.lower() in ['fault', 'label', 'class', 'target'] for col in df.columns)
+     if not has_label:
+         print("Warning: No label column found. Adding dummy 'Fault' column with value 'Normal' for all samples.")
+         df['Fault'] = 'Normal'  # Add dummy label for training
+
+     # Create column mapping - map similar column names to expected format
+     column_mapping = {}
+     expected_cols = list(REQUIRED_PMU_COLUMNS)
+
+     # If we have at least the right number of numeric columns after Timestamp, use positional mapping
+     if "Timestamp" in df.columns:
+         numeric_cols = [col for col in df.columns if col != "Timestamp"]
+         if len(numeric_cols) >= len(expected_cols):
+             # Map by position (after Timestamp)
+             for i, expected_col in enumerate(expected_cols):
+                 if i < len(numeric_cols):
+                     column_mapping[numeric_cols[i]] = expected_col
+
+     # Rename columns to match expected format
+     df = df.rename(columns=column_mapping)
+
+     # Check if we have the required columns after mapping
+     missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]
+     if missing:
+         # If still missing, try a more flexible approach
+         available_numeric = df.select_dtypes(include=[np.number]).columns.tolist()
+         if len(available_numeric) >= len(expected_cols):
+             # Use the first N numeric columns
+             for i, expected_col in enumerate(expected_cols):
+                 if i < len(available_numeric):
+                     if available_numeric[i] not in df.columns:
+                         continue
+                     df = df.rename(columns={available_numeric[i]: expected_col})
+
+             # Recheck missing columns
+             missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]
+
+     if missing:
+         missing_str = ", ".join(missing)
+         available_str = ", ".join(df.columns.tolist())
+         raise ValueError(
+             f"Missing required PMU feature columns: {missing_str}. "
+             f"Available columns: {available_str}. "
+             "Please ensure your CSV has the correct format with Timestamp followed by PMU measurements."
+         )
+
+     return df
+
+
+ def apply_metadata(metadata: Dict[str, Any]) -> None:
+     global FEATURE_COLUMNS, LABEL_CLASSES, LABEL_COLUMN, SEQUENCE_LENGTH, DEFAULT_WINDOW_STRIDE, MODEL_TYPE, MODEL_FORMAT
+     FEATURE_COLUMNS = [str(col) for col in metadata.get("feature_columns", DEFAULT_FEATURE_COLUMNS)]
+     LABEL_CLASSES = [str(label) for label in metadata.get("label_classes", [])]
+     LABEL_COLUMN = str(metadata.get("label_column", "Fault"))
+     SEQUENCE_LENGTH = int(metadata.get("sequence_length", DEFAULT_SEQUENCE_LENGTH))
+     DEFAULT_WINDOW_STRIDE = int(metadata.get("stride", DEFAULT_STRIDE))
+     MODEL_TYPE = str(metadata.get("model_type", "cnn_lstm")).lower()
+     MODEL_FORMAT = str(
+         metadata.get("model_format", "joblib" if MODEL_TYPE == "svm" else "keras")
+     ).lower()
+
+
+ apply_metadata(METADATA)
+
+ def sync_label_classes_from_model(model: Optional[object]) -> None:
+     global LABEL_CLASSES
+     if model is None:
+         return
+     if hasattr(model, "classes_"):
+         LABEL_CLASSES = [str(label) for label in getattr(model, "classes_")]
+     elif not LABEL_CLASSES and hasattr(model, "output_shape"):
+         LABEL_CLASSES = [str(i) for i in range(int(model.output_shape[-1]))]
+
+
+ # Load model and scaler with error handling
+ print("Loading model and scaler...")
+ try:
+     MODEL = try_load_model(MODEL_PATH, MODEL_TYPE, MODEL_FORMAT)
+     print(f"Model loaded: {MODEL is not None}")
+ except Exception as e:
+     print(f"Model loading failed: {e}")
+     MODEL = None
+
+ try:
+     SCALER = try_load_scaler(SCALER_PATH)
+     print(f"Scaler loaded: {SCALER is not None}")
+ except Exception as e:
+     print(f"Scaler loading failed: {e}")
+     SCALER = None
+
+ try:
+     sync_label_classes_from_model(MODEL)
+     print("Label classes synchronized")
+ except Exception as e:
+     print(f"Label sync failed: {e}")
+
+ print("Application initialization completed.")
+ print(f"Ready to start Gradio interface. Model available: {MODEL is not None}, Scaler available: {SCALER is not None}")
+
+
+ def refresh_artifacts(model_path: Path, scaler_path: Path, metadata_path: Path) -> None:
+     global MODEL_PATH, SCALER_PATH, METADATA_PATH, MODEL, SCALER, METADATA
+     MODEL_PATH = model_path
+     SCALER_PATH = scaler_path
+     METADATA_PATH = metadata_path
+     METADATA = load_metadata(metadata_path)
+     apply_metadata(METADATA)
+     MODEL = try_load_model(model_path, MODEL_TYPE, MODEL_FORMAT)
+     SCALER = try_load_scaler(scaler_path)
+     sync_label_classes_from_model(MODEL)
+
+ # --------------------------------------------------------------------------------------
+ # Pre-processing helpers
+ # --------------------------------------------------------------------------------------
+
+ def ensure_ready():
+     if MODEL is None or SCALER is None:
+         raise RuntimeError(
+             "The model and feature scaler are not available. Upload the trained model "
+             "(for example `pmu_cnn_lstm_model.keras`, `pmu_tcn_model.keras`, or `pmu_svm_model.joblib`), "
+             "the feature scaler (`pmu_feature_scaler.pkl`), and the metadata JSON (`pmu_metadata.json`) to the Space root "
+             "or configure the Hugging Face Hub environment variables so the artifacts can be downloaded "
+             "automatically."
+         )
+
+ def parse_text_features(text: str) -> np.ndarray:
+     cleaned = re.sub(r"[;\n\t]+", ",", text.strip())
+     # ``np.fromstring`` is deprecated for text input, so parse the tokens explicitly.
+     tokens = [token for token in cleaned.split(",") if token.strip()]
+     if not tokens:
+         raise ValueError("No feature values were parsed. Please enter comma-separated numbers.")
+     return np.array(tokens, dtype=np.float32)
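+ # Example: pasting "49.97772,1.215825E-38,..." (one value per feature column,
+ # repeated for each timestep) parses to a flat float32 array, which
+ # predict_from_text below reshapes to (1, sequence_length, n_features).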
+
+
+ def apply_scaler(sequences: np.ndarray) -> np.ndarray:
+     if SCALER is None:
+         return sequences
+     shape = sequences.shape
+     flattened = sequences.reshape(-1, shape[-1])
+     scaled = SCALER.transform(flattened)
+     return scaled.reshape(shape)
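+ # The scaler was fitted on per-timestep feature rows, so a batch of shape
+ # (n_windows, seq_len, n_features) is flattened to (n_windows * seq_len,
+ # n_features) for transform() and then restored to its original shape.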
+
+
+ def make_sliding_windows(data: np.ndarray, sequence_length: int, stride: int) -> np.ndarray:
+     if data.shape[0] < sequence_length:
+         raise ValueError(
+             f"The dataset contains {data.shape[0]} rows which is less than the requested sequence "
+             f"length {sequence_length}. Provide more samples or reduce the sequence length."
+         )
+     windows = [
+         data[start : start + sequence_length]
+         for start in range(0, data.shape[0] - sequence_length + 1, stride)
+     ]
+     return np.stack(windows)
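+ # Window count sanity check: len(windows) == floor((rows - sequence_length)
+ # / stride) + 1. With the defaults (sequence_length=32, stride=4), 100 rows
+ # yield floor(68 / 4) + 1 = 18 windows of shape (32, n_features).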
+
+
+ def dataframe_to_sequences(
+     df: pd.DataFrame,
+     *,
+     sequence_length: int,
+     stride: int,
+     feature_columns: Sequence[str],
+     drop_label: bool = True,
+ ) -> np.ndarray:
+     work_df = df.copy()
+     if drop_label and LABEL_COLUMN in work_df.columns:
+         work_df = work_df.drop(columns=[LABEL_COLUMN])
+     if "Timestamp" in work_df.columns:
+         work_df = work_df.sort_values("Timestamp")
+
+     available_cols = [c for c in feature_columns if c in work_df.columns]
+     n_features = len(feature_columns)
+     if available_cols and len(available_cols) == n_features:
+         array = work_df[available_cols].astype(np.float32).to_numpy()
+         return make_sliding_windows(array, sequence_length, stride)
+
+     numeric_df = work_df.select_dtypes(include=[np.number])
+     array = numeric_df.astype(np.float32).to_numpy()
+     if array.shape[1] == n_features * sequence_length:
+         return array.reshape(array.shape[0], sequence_length, n_features)
+     if sequence_length == 1 and array.shape[1] == n_features:
+         return array.reshape(array.shape[0], 1, n_features)
+     raise ValueError(
+         "CSV columns do not match the expected feature layout. Include the full PMU feature set "
+         "or provide pre-shaped sliding window data."
+     )
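+ # Accepted layouts, in order of preference: (1) raw time series containing
+ # every configured feature column, converted to sliding windows; (2) rows
+ # already flattened to sequence_length * n_features values; (3) single-row
+ # windows when sequence_length == 1.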
+
+
+ def label_name(index: int) -> str:
+     if 0 <= index < len(LABEL_CLASSES):
+         return str(LABEL_CLASSES[index])
+     return f"class_{index}"
+
+
+ def format_predictions(probabilities: np.ndarray) -> pd.DataFrame:
+     rows: List[Dict[str, object]] = []
+     order = np.argsort(probabilities, axis=1)[:, ::-1]
+     for idx, (prob_row, ranking) in enumerate(zip(probabilities, order)):
+         top_idx = int(ranking[0])
+         top_label = label_name(top_idx)
+         top_conf = float(prob_row[top_idx])
+         top3 = [f"{label_name(i)} ({prob_row[i]*100:.2f}%)" for i in ranking[:3]]
+         rows.append(
+             {
+                 "window": idx,
+                 "predicted_label": top_label,
+                 "confidence": round(top_conf, 4),
+                 "top3": " | ".join(top3),
+             }
+         )
+     return pd.DataFrame(rows)
+
+
+ def probabilities_to_json(probabilities: np.ndarray) -> List[Dict[str, object]]:
+     payload: List[Dict[str, object]] = []
+     for idx, prob_row in enumerate(probabilities):
+         payload.append(
+             {
+                 "window": int(idx),
+                 "probabilities": {label_name(i): float(prob_row[i]) for i in range(prob_row.shape[0])},
+             }
+         )
+     return payload
+
+
+ def predict_sequences(sequences: np.ndarray) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
+     ensure_ready()
+     sequences = apply_scaler(sequences.astype(np.float32))
+     if MODEL_TYPE == "svm":
+         flattened = sequences.reshape(sequences.shape[0], -1)
+         if hasattr(MODEL, "predict_proba"):
+             probs = MODEL.predict_proba(flattened)
+         else:
+             raise RuntimeError("Loaded SVM model does not expose predict_proba. Retrain with probability=True.")
+     else:
+         probs = MODEL.predict(sequences, verbose=0)
+     table = format_predictions(probs)
+     json_probs = probabilities_to_json(probs)
+     architecture = MODEL_TYPE.replace("_", "-").upper()
+     status = f"Generated {len(sequences)} windows. {architecture} model output dimension: {probs.shape[1]}."
+     return status, table, json_probs
+
+
+ def predict_from_text(text: str, sequence_length: int) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
+     arr = parse_text_features(text)
+     n_features = len(FEATURE_COLUMNS)
+     if arr.size % n_features != 0:
+         raise ValueError(
+             f"The number of values ({arr.size}) is not a multiple of the feature dimension "
+             f"({n_features}). Provide values in groups of {n_features}."
+         )
+     timesteps = arr.size // n_features
+     if timesteps != sequence_length:
+         raise ValueError(
+             f"Detected {timesteps} timesteps which does not match the configured sequence length "
+             f"({sequence_length})."
+         )
+     sequences = arr.reshape(1, sequence_length, n_features)
+     status, table, probs = predict_sequences(sequences)
+     status = f"Single window prediction complete. {status}"
+     return status, table, probs
+
+
+ def predict_from_csv(file_obj, sequence_length: int, stride: int) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
+     df = load_measurement_csv(file_obj.name)
+     sequences = dataframe_to_sequences(
+         df,
+         sequence_length=sequence_length,
+         stride=stride,
+         feature_columns=FEATURE_COLUMNS,
+     )
+     status, table, probs = predict_sequences(sequences)
+     status = f"CSV processed successfully. Generated {len(sequences)} windows. {status}"
+     return status, table, probs
+
+
+ # --------------------------------------------------------------------------------------
+ # Training helpers
+ # --------------------------------------------------------------------------------------
+
+
+ def classification_report_to_dataframe(report: Dict[str, Any]) -> pd.DataFrame:
+     rows: List[Dict[str, Any]] = []
+     for label, metrics in report.items():
+         if isinstance(metrics, dict):
+             row = {"label": label}
+             for key, value in metrics.items():
+                 if key == "support":
+                     row[key] = int(value)
+                 else:
+                     row[key] = round(float(value), 4)
+             rows.append(row)
+         else:
+             rows.append({"label": label, "accuracy": round(float(metrics), 4)})
+     return pd.DataFrame(rows)
+
+
+ def confusion_matrix_to_dataframe(confusion: Sequence[Sequence[float]], labels: Sequence[str]) -> pd.DataFrame:
+     if not confusion:
+         return pd.DataFrame()
+     df = pd.DataFrame(confusion, index=list(labels), columns=list(labels))
+     df.index.name = "True Label"
+     df.columns.name = "Predicted Label"
+     return df
+
+
+ # --------------------------------------------------------------------------------------
+ # Gradio interface
+ # --------------------------------------------------------------------------------------
1456
+
1457
+ def build_interface() -> gr.Blocks:
+     theme = gr.themes.Soft(primary_hue="sky", secondary_hue="blue", neutral_hue="gray").set(
+         body_background_fill="#1f1f1f",
+         body_text_color="#f5f5f5",
+         block_background_fill="#262626",
+         block_border_color="#333333",
+         button_primary_background_fill="#5ac8fa",
+         button_primary_background_fill_hover="#48b5eb",
+         button_primary_border_color="#38bdf8",
+         button_primary_text_color="#0f172a",
+         button_secondary_background_fill="#3f3f46",
+         button_secondary_text_color="#f5f5f5",
+     )
+
+     def _normalise_directory_string(value: Optional[Union[str, Path]]) -> str:
+         if value is None:
+             return ""
+         path = Path(value).expanduser()
+         try:
+             return str(path.resolve())
+         except Exception:
+             return str(path)
+
+     with gr.Blocks(title="Fault Classification - PMU Data", theme=theme, css=APP_CSS) as demo:
+         gr.Markdown("# Fault Classification for PMU & PV Data")
+         gr.Markdown(
+             "🖥️ TensorFlow is locked to CPU execution so the Space can run without CUDA drivers."
+         )
+         if MODEL is None or SCALER is None:
+             gr.Markdown(
+                 "⚠️ **Artifacts Missing** — Upload `pmu_cnn_lstm_model.keras`, "
+                 "`pmu_feature_scaler.pkl`, and `pmu_metadata.json` to enable inference, "
+                 "or configure the Hugging Face Hub environment variables so they can be downloaded."
+             )
+         else:
+             class_count = len(LABEL_CLASSES) if LABEL_CLASSES else "unknown"
+             gr.Markdown(
+                 f"Loaded a **{MODEL_TYPE.upper()}** model ({MODEL_FORMAT.upper()}) with "
+                 f"{len(FEATURE_COLUMNS)} features, sequence length **{SEQUENCE_LENGTH}**, and "
+                 f"{class_count} target classes. Use the tabs below to run inference or fine-tune "
+                 "the model with your own CSV files."
+             )
+
+         with gr.Accordion("Feature Reference", open=False):
+             gr.Markdown(
+                 f"Each time window expects **{len(FEATURE_COLUMNS)} features** ordered as follows:\n"
+                 + "\n".join(f"- {name}" for name in FEATURE_COLUMNS)
+             )
+             gr.Markdown(
+                 f"Default training parameters: **sequence length = {SEQUENCE_LENGTH}**, "
+                 f"**stride = {DEFAULT_WINDOW_STRIDE}**. Adjust them in the tabs as needed."
+             )
+
+         with gr.Tabs():
+             with gr.Tab("Overview"):
+                 gr.Markdown(PROJECT_OVERVIEW_MD)
+             with gr.Tab("Inference"):
+                 gr.Markdown("## Run Inference")
+                 with gr.Row():
+                     file_in = gr.File(label="Upload PMU CSV", file_types=[".csv"])
+                     text_in = gr.Textbox(
+                         lines=4,
+                         label="Or paste a single window (comma separated)",
+                         placeholder="49.97772,1.215825E-38,...",
+                     )
+
+                 with gr.Row():
+                     sequence_length_input = gr.Slider(
+                         minimum=1,
+                         maximum=max(1, SEQUENCE_LENGTH * 2),
+                         step=1,
+                         value=SEQUENCE_LENGTH,
+                         label="Sequence length (timesteps)",
+                     )
+                     stride_input = gr.Slider(
+                         minimum=1,
+                         maximum=max(1, SEQUENCE_LENGTH),
+                         step=1,
+                         value=max(1, DEFAULT_WINDOW_STRIDE),
+                         label="CSV window stride",
+                     )
+
+                 predict_btn = gr.Button("🚀 Run Inference", variant="primary")
+                 status_out = gr.Textbox(label="Status", interactive=False)
+                 table_out = gr.Dataframe(
+                     headers=["window", "predicted_label", "confidence", "top3"],
+                     label="Predictions",
+                     interactive=False,
+                 )
+                 probs_out = gr.JSON(label="Per-window probabilities")
+
+                 def _run_prediction(file_obj, text, sequence_length, stride):
+                     sequence_length = int(sequence_length)
+                     stride = int(stride)
+                     try:
+                         if file_obj is not None:
+                             return predict_from_csv(file_obj, sequence_length, stride)
+                         if text and text.strip():
+                             return predict_from_text(text, sequence_length)
+                         return "Please upload a CSV file or provide feature values.", pd.DataFrame(), []
+                     except Exception as exc:
+                         return f"Prediction failed: {exc}", pd.DataFrame(), []
+
+                 predict_btn.click(
+                     _run_prediction,
+                     inputs=[file_in, text_in, sequence_length_input, stride_input],
+                     outputs=[status_out, table_out, probs_out],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
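+
+                 # Note the dispatch precedence in _run_prediction above: an uploaded
+                 # CSV always wins over pasted text, and the stride only applies to
+                 # the CSV path (a pasted vector is treated as exactly one window).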
+
+             with gr.Tab("Training"):
+                 gr.Markdown("## Train or Fine-tune the Model")
+                 gr.Markdown(
+                     "Training data is automatically downloaded from the database. "
+                     "Refresh the cache if new files are added upstream."
+                 )
+
+                 training_files_state = gr.State([])
+                 with gr.Row():
+                     with gr.Column(scale=3):
+                         training_files_summary = gr.Textbox(
+                             label="Database training CSVs",
+                             value="Training dataset not loaded yet.",
+                             lines=4,
+                             interactive=False,
+                             elem_id="training-files-summary",
+                         )
+                     with gr.Column(scale=2, min_width=240):
+                         dataset_info = gr.Markdown(
+                             "No local database CSVs downloaded yet.",
+                         )
+                         dataset_refresh = gr.Button(
+                             "🔄 Reload dataset from database",
+                             variant="secondary",
+                         )
+                         clear_cache_button = gr.Button(
+                             "🧹 Clear downloaded cache",
+                             variant="secondary",
+                         )
+
+                 with gr.Accordion("📂 Database Browser", open=False):
+                     gr.Markdown(
+                         "Browse the upstream database by date and download only the CSVs you need."
+                     )
+                     with gr.Row(elem_id="date-browser-row"):
+                         with gr.Column(scale=1, elem_classes=["date-browser-column"]):
+                             year_selector = gr.Dropdown(label="Year", choices=[])
+                             year_download_button = gr.Button(
+                                 "⬇️ Download year CSVs", variant="secondary"
+                             )
+                         with gr.Column(scale=1, elem_classes=["date-browser-column"]):
+                             month_selector = gr.Dropdown(label="Month", choices=[])
+                             month_download_button = gr.Button(
+                                 "⬇️ Download month CSVs", variant="secondary"
+                             )
+                         with gr.Column(scale=1, elem_classes=["date-browser-column"]):
+                             day_selector = gr.Dropdown(label="Day", choices=[])
+                             day_download_button = gr.Button(
+                                 "⬇️ Download day CSVs", variant="secondary"
+                             )
+                     with gr.Column(elem_id="available-files-section"):
+                         available_files = gr.CheckboxGroup(
+                             label="Available CSV files",
+                             choices=[],
+                             value=[],
+                             elem_id="available-files-grid",
+                         )
+                         download_button = gr.Button(
+                             "⬇️ Download selected CSVs",
+                             variant="secondary",
+                             elem_id="download-selected-button",
+                         )
+                     repo_status = gr.Markdown(
+                         "Click 'Reload dataset from database' to fetch the directory tree."
+                     )
+
+                 with gr.Row():
+                     label_input = gr.Dropdown(
+                         value=LABEL_COLUMN,
+                         choices=[LABEL_COLUMN],
+                         allow_custom_value=True,
+                         label="Label column name",
+                     )
+                     model_selector = gr.Radio(
+                         choices=["CNN-LSTM", "TCN", "SVM"],
+                         value=(
+                             "TCN"
+                             if MODEL_TYPE == "tcn"
+                             else ("SVM" if MODEL_TYPE == "svm" else "CNN-LSTM")
+                         ),
+                         label="Model architecture",
+                     )
+                     sequence_length_train = gr.Slider(
+                         minimum=4,
+                         maximum=max(32, SEQUENCE_LENGTH * 2),
+                         step=1,
+                         value=SEQUENCE_LENGTH,
+                         label="Sequence length",
+                     )
+                     stride_train = gr.Slider(
+                         minimum=1,
+                         maximum=max(32, SEQUENCE_LENGTH * 2),
+                         step=1,
+                         value=max(1, DEFAULT_WINDOW_STRIDE),
+                         label="Stride",
+                     )
+
+                 model_default = MODEL_FILENAME_BY_TYPE.get(
+                     MODEL_TYPE, Path(LOCAL_MODEL_FILE).name
+                 )
+
+                 with gr.Row():
+                     validation_train = gr.Slider(
+                         minimum=0.05,
+                         maximum=0.4,
+                         step=0.05,
+                         value=0.2,
+                         label="Validation split",
+                     )
+                     batch_train = gr.Slider(
+                         minimum=32,
+                         maximum=512,
+                         step=32,
+                         value=128,
+                         label="Batch size",
+                     )
+                     epochs_train = gr.Slider(
+                         minimum=5,
+                         maximum=100,
+                         step=5,
+                         value=50,
+                         label="Epochs",
+                     )
+
+                 directory_choices, directory_default = gather_directory_choices(
+                     str(MODEL_OUTPUT_DIR)
+                 )
+                 artifact_choices, default_artifact = gather_artifact_choices(
+                     directory_default
+                 )
+
+                 with gr.Row():
+                     output_directory = gr.Dropdown(
+                         value=directory_default,
+                         label="Output directory",
+                         choices=directory_choices,
+                         allow_custom_value=True,
+                     )
+                     model_name = gr.Textbox(
+                         value=model_default,
+                         label="Model output filename",
+                     )
+                     scaler_name = gr.Textbox(
+                         value=Path(LOCAL_SCALER_FILE).name,
+                         label="Scaler output filename",
+                     )
+                     metadata_name = gr.Textbox(
+                         value=Path(LOCAL_METADATA_FILE).name,
+                         label="Metadata output filename",
+                     )
+
+                 with gr.Row():
+                     artifact_browser = gr.Dropdown(
+                         label="Saved artifacts in directory",
+                         choices=artifact_choices,
+                         value=default_artifact,
+                     )
+                     artifact_download_button = gr.DownloadButton(
+                         "⬇️ Download selected artifact",
+                         value=default_artifact,
+                         visible=bool(default_artifact),
+                         variant="secondary",
+                     )
+
+                 def on_output_directory_change(selected_dir, current_selection):
+                     choices, normalised = gather_directory_choices(selected_dir)
+                     artifact_options, selected = gather_artifact_choices(
+                         normalised, current_selection
+                     )
+                     return (
+                         gr.update(choices=choices, value=normalised),
+                         gr.update(choices=artifact_options, value=selected),
+                         download_button_state(selected),
+                     )
+
+                 def on_artifact_change(selected_path):
+                     return download_button_state(selected_path)
+
+                 output_directory.change(
+                     on_output_directory_change,
+                     inputs=[output_directory, artifact_browser],
+                     outputs=[output_directory, artifact_browser, artifact_download_button],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 artifact_browser.change(
+                     on_artifact_change,
+                     inputs=[artifact_browser],
+                     outputs=[artifact_download_button],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 with gr.Row(elem_id="artifact-download-row"):
+                     model_download_button = gr.DownloadButton(
+                         "⬇️ Download model file",
+                         value=None,
+                         visible=False,
+                         elem_classes=["artifact-download-button"],
+                     )
+                     scaler_download_button = gr.DownloadButton(
+                         "⬇️ Download scaler file",
+                         value=None,
+                         visible=False,
+                         elem_classes=["artifact-download-button"],
+                     )
+                     metadata_download_button = gr.DownloadButton(
+                         "⬇️ Download metadata file",
+                         value=None,
+                         visible=False,
+                         elem_classes=["artifact-download-button"],
+                     )
+                     tensorboard_download_button = gr.DownloadButton(
+                         "⬇️ Download TensorBoard logs",
+                         value=None,
+                         visible=False,
+                         elem_classes=["artifact-download-button"],
+                     )
+
+                 model_download_button.file_name = Path(LOCAL_MODEL_FILE).name
+                 scaler_download_button.file_name = Path(LOCAL_SCALER_FILE).name
+                 metadata_download_button.file_name = Path(LOCAL_METADATA_FILE).name
+                 tensorboard_download_button.file_name = "tensorboard_logs.zip"
+
+                 tensorboard_toggle = gr.Checkbox(
+                     value=True,
+                     label="Enable TensorBoard logging (creates downloadable archive)",
+                 )
+
+                 def _suggest_model_filename(choice: str, current_value: str):
+                     choice_key = (choice or "cnn_lstm").lower().replace("-", "_")
+                     suggested = MODEL_FILENAME_BY_TYPE.get(
+                         choice_key, Path(LOCAL_MODEL_FILE).name
+                     )
+                     known_defaults = set(MODEL_FILENAME_BY_TYPE.values())
+                     current_name = Path(current_value).name if current_value else ""
+                     if current_name and current_name not in known_defaults:
+                         return gr.update()
+                     return gr.update(value=suggested)
+
+                 model_selector.change(
+                     _suggest_model_filename,
+                     inputs=[model_selector, model_name],
+                     outputs=model_name,
+                 )
+
+                 with gr.Row():
+                     train_button = gr.Button("🛠️ Start Training", variant="primary")
+                     progress_button = gr.Button("📊 Check Progress", variant="secondary")
+
+                 # Training status display
+                 training_status = gr.Textbox(label="Training Status", interactive=False)
+                 report_output = gr.Dataframe(label="Classification report", interactive=False)
+                 history_output = gr.JSON(label="Training history")
+                 confusion_output = gr.Dataframe(label="Confusion matrix", interactive=False)
+
+                 # Message area at the bottom for progress updates
+                 with gr.Accordion("📋 Progress Messages", open=True):
+                     progress_messages = gr.Textbox(
+                         label="Training Messages",
+                         lines=8,
+                         max_lines=20,
+                         interactive=False,
+                         autoscroll=True,
+                         placeholder="Click 'Check Progress' to see training updates...",
+                     )
+                     with gr.Row():
+                         gr.Button("🗑️ Clear Messages", variant="secondary").click(
+                             lambda: "",
+                             outputs=[progress_messages],
+                         )
+
+                 def _run_training(
+                     file_paths,
+                     label_column,
+                     model_choice,
+                     sequence_length,
+                     stride,
+                     validation_split,
+                     batch_size,
+                     epochs,
+                     output_dir,
+                     model_filename,
+                     scaler_filename,
+                     metadata_filename,
+                     enable_tensorboard,
+                 ):
+                     base_dir = normalise_output_directory(output_dir)
+                     try:
+                         base_dir.mkdir(parents=True, exist_ok=True)
+
+                         model_path = resolve_output_path(
+                             base_dir,
+                             model_filename,
+                             Path(LOCAL_MODEL_FILE).name,
+                         )
+                         scaler_path = resolve_output_path(
+                             base_dir,
+                             scaler_filename,
+                             Path(LOCAL_SCALER_FILE).name,
+                         )
+                         metadata_path = resolve_output_path(
+                             base_dir,
+                             metadata_filename,
+                             Path(LOCAL_METADATA_FILE).name,
+                         )
+
+                         model_path.parent.mkdir(parents=True, exist_ok=True)
+                         scaler_path.parent.mkdir(parents=True, exist_ok=True)
+                         metadata_path.parent.mkdir(parents=True, exist_ok=True)
+
+                         # Create status file path for progress tracking
+                         status_file = model_path.parent / "training_status.txt"
+
+                         # Initialize status
+                         with open(status_file, "w") as f:
+                             f.write("Starting training setup...")
+
+                         if not file_paths:
+                             raise ValueError(
+                                 "No training CSVs were found in the database cache. "
+                                 "Use 'Reload dataset from database' and try again."
+                             )
+
+                         with open(status_file, "w") as f:
+                             f.write("Loading and validating CSV files...")
+
+                         available_paths = [path for path in file_paths if Path(path).exists()]
+                         missing_paths = [Path(path).name for path in file_paths if not Path(path).exists()]
+                         if not available_paths:
+                             raise ValueError(
+                                 "Database training dataset is unavailable. Reload the dataset and retry."
+                             )
+
+                         dfs = [load_measurement_csv(path) for path in available_paths]
+                         combined = pd.concat(dfs, ignore_index=True)
+
+                         # Validate data size and provide recommendations
+                         total_samples = len(combined)
+                         if total_samples < 10:
+                             raise ValueError(
+                                 f"Insufficient data: {total_samples} samples. "
+                                 "Need at least 10 samples for training."
+                             )
+                         if total_samples < 100:
+                             print(f"Warning: only {total_samples} samples. At least 1000 are recommended for good results.")
+                             # Compare against the normalised names: the raw radio value is
+                             # "CNN-LSTM"/"TCN"/"SVM" at this point, so a plain membership
+                             # test on "cnn_lstm"/"tcn" would never match.
+                             if (model_choice or "").lower().replace("-", "_") in {"cnn_lstm", "tcn"}:
+                                 model_choice = "svm"
+                                 print("Automatically switching to SVM for small-dataset compatibility.")
+
+                         label_column = (label_column or LABEL_COLUMN).strip()
+                         if not label_column:
+                             raise ValueError("Label column name cannot be empty.")
+
+                         model_choice = (model_choice or "CNN-LSTM").lower().replace("-", "_")
+                         if model_choice not in {"cnn_lstm", "tcn", "svm"}:
+                             raise ValueError("Select CNN-LSTM, TCN, or SVM for the model architecture.")
+
+                         with open(status_file, "w") as f:
+                             f.write(f"Starting {model_choice.upper()} training with {len(combined)} samples...")
+
+                         # Start training
+                         result = train_from_dataframe(
+                             combined,
+                             label_column=label_column,
+                             feature_columns=None,
+                             sequence_length=int(sequence_length),
+                             stride=int(stride),
+                             validation_split=float(validation_split),
+                             batch_size=int(batch_size),
+                             epochs=int(epochs),
+                             model_type=model_choice,
+                             model_path=model_path,
+                             scaler_path=scaler_path,
+                             metadata_path=metadata_path,
+                             enable_tensorboard=bool(enable_tensorboard),
+                         )
+
+                         refresh_artifacts(
+                             Path(result["model_path"]),
+                             Path(result["scaler_path"]),
+                             Path(result["metadata_path"]),
+                         )
+
+                         report_df = classification_report_to_dataframe(result["classification_report"])
+                         confusion_df = confusion_matrix_to_dataframe(result["confusion_matrix"], result["class_names"])
+                         tensorboard_dir = result.get("tensorboard_log_dir")
+                         tensorboard_zip = result.get("tensorboard_zip_path")
+
+                         architecture = result["model_type"].replace("_", "-").upper()
+                         status = (
+                             f"Training complete using a {architecture} architecture. "
+                             f"{result['num_sequences']} windows derived from "
+                             f"{result['num_samples']} rows across {len(available_paths)} file(s)."
+                             f" Artifacts saved to:"
+                             f"\n• Model: {result['model_path']}\n"
+                             f"• Scaler: {result['scaler_path']}\n"
+                             f"• Metadata: {result['metadata_path']}"
+                         )
+
+                         status += f"\nLabel column used: {result.get('label_column', label_column)}"
+
+                         if tensorboard_dir:
+                             status += (
+                                 f"\nTensorBoard logs directory: {tensorboard_dir}"
+                                 f"\nRun `tensorboard --logdir \"{tensorboard_dir}\"` to inspect the training curves."
+                                 "\nDownload the archive below to explore the run offline."
+                             )
+
+                         if missing_paths:
+                             skipped = ", ".join(missing_paths)
+                             status = f"⚠️ Skipped missing files: {skipped}\n" + status
+
+                         artifact_choices, selected_artifact = gather_artifact_choices(
+                             str(base_dir), result["model_path"]
+                         )
+
+                         return (
+                             status,
+                             report_df,
+                             result["history"],
+                             confusion_df,
+                             download_button_state(result["model_path"]),
+                             download_button_state(result["scaler_path"]),
+                             download_button_state(result["metadata_path"]),
+                             download_button_state(tensorboard_zip),
+                             gr.update(value=result.get("label_column", label_column)),
+                             gr.update(choices=artifact_choices, value=selected_artifact),
+                             download_button_state(selected_artifact),
+                         )
+                     except Exception as exc:
+                         artifact_choices, selected_artifact = gather_artifact_choices(
+                             str(base_dir)
+                         )
+                         return (
+                             f"Training failed: {exc}",
+                             pd.DataFrame(),
+                             {},
+                             pd.DataFrame(),
+                             download_button_state(None),
+                             download_button_state(None),
+                             download_button_state(None),
+                             download_button_state(None),
+                             gr.update(),
+                             gr.update(choices=artifact_choices, value=selected_artifact),
+                             download_button_state(selected_artifact),
+                         )
+
+                 def _check_progress(output_dir, model_filename, current_messages):
+                     """Check training progress by reading status file and accumulate messages."""
+                     model_path = resolve_output_path(
+                         output_dir, model_filename, Path(LOCAL_MODEL_FILE).name
+                     )
+                     status_file = model_path.parent / "training_status.txt"
+                     status_message = read_training_status(str(status_file))
+
+                     # Add timestamp to the message
+                     from datetime import datetime
+                     timestamp = datetime.now().strftime("%H:%M:%S")
+                     new_message = f"[{timestamp}] {status_message}"
+
+                     # Accumulate messages, keeping last 50 lines to prevent overflow
+                     if current_messages:
+                         lines = current_messages.split("\n")
+                         lines.append(new_message)
+                         # Keep only last 50 lines
+                         if len(lines) > 50:
+                             lines = lines[-50:]
+                         accumulated_messages = "\n".join(lines)
+                     else:
+                         accumulated_messages = new_message
+
+                     return accumulated_messages
+
+                 train_button.click(
+                     _run_training,
+                     inputs=[
+                         training_files_state,
+                         label_input,
+                         model_selector,
+                         sequence_length_train,
+                         stride_train,
+                         validation_train,
+                         batch_train,
+                         epochs_train,
+                         output_directory,
+                         model_name,
+                         scaler_name,
+                         metadata_name,
+                         tensorboard_toggle,
+                     ],
+                     outputs=[
+                         training_status,
+                         report_output,
+                         history_output,
+                         confusion_output,
+                         model_download_button,
+                         scaler_download_button,
+                         metadata_download_button,
+                         tensorboard_download_button,
+                         label_input,
+                         artifact_browser,
+                         artifact_download_button,
+                     ],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 progress_button.click(
+                     _check_progress,
+                     inputs=[output_directory, model_name, progress_messages],
+                     outputs=[progress_messages],
+                 )
+
+                 year_selector.change(
+                     on_year_change,
+                     inputs=[year_selector],
+                     outputs=[month_selector, day_selector, available_files, repo_status],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 month_selector.change(
+                     on_month_change,
+                     inputs=[year_selector, month_selector],
+                     outputs=[day_selector, available_files, repo_status],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 day_selector.change(
+                     on_day_change,
+                     inputs=[year_selector, month_selector, day_selector],
+                     outputs=[available_files, repo_status],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 download_button.click(
+                     download_selected_files,
+                     inputs=[
+                         year_selector,
+                         month_selector,
+                         day_selector,
+                         available_files,
+                         label_input,
+                     ],
+                     outputs=[
+                         training_files_state,
+                         training_files_summary,
+                         label_input,
+                         dataset_info,
+                         available_files,
+                         repo_status,
+                     ],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 year_download_button.click(
+                     download_year_bundle,
+                     inputs=[year_selector, label_input],
+                     outputs=[
+                         training_files_state,
+                         training_files_summary,
+                         label_input,
+                         dataset_info,
+                         available_files,
+                         repo_status,
+                     ],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 month_download_button.click(
+                     download_month_bundle,
+                     inputs=[year_selector, month_selector, label_input],
+                     outputs=[
+                         training_files_state,
+                         training_files_summary,
+                         label_input,
+                         dataset_info,
+                         available_files,
+                         repo_status,
+                     ],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 day_download_button.click(
+                     download_day_bundle,
+                     inputs=[year_selector, month_selector, day_selector, label_input],
+                     outputs=[
+                         training_files_state,
+                         training_files_summary,
+                         label_input,
+                         dataset_info,
+                         available_files,
+                         repo_status,
+                     ],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 def _reload_dataset(current_label):
+                     local = load_repository_training_files(current_label, force_refresh=True)
+                     remote = refresh_remote_browser(force_refresh=True)
+                     return (*local, *remote)
+
+                 dataset_refresh.click(
+                     _reload_dataset,
+                     inputs=[label_input],
+                     outputs=[
+                         training_files_state,
+                         training_files_summary,
+                         label_input,
+                         dataset_info,
+                         year_selector,
+                         month_selector,
+                         day_selector,
+                         available_files,
+                         repo_status,
+                     ],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 clear_cache_button.click(
+                     clear_downloaded_cache,
+                     inputs=[label_input],
+                     outputs=[
+                         training_files_state,
+                         training_files_summary,
+                         label_input,
+                         dataset_info,
+                         year_selector,
+                         month_selector,
+                         day_selector,
+                         available_files,
+                         repo_status,
+                     ],
+                     concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+                 )
+
+                 def _initialise_dataset():
+                     local = load_repository_training_files(LABEL_COLUMN, force_refresh=False)
+                     remote = refresh_remote_browser(force_refresh=False)
+                     return (*local, *remote)
+
+                 demo.load(
+                     _initialise_dataset,
+                     inputs=None,
+                     outputs=[
+                         training_files_state,
+                         training_files_summary,
+                         label_input,
+                         dataset_info,
+                         year_selector,
+                         month_selector,
+                         day_selector,
+                         available_files,
+                         repo_status,
+                     ],
+                     queue=False,
+                 )
+
+     return demo
+
+
+ # --------------------------------------------------------------------------------------
+ # Launch helpers
+ # --------------------------------------------------------------------------------------
+
+ def resolve_server_port() -> int:
+     for env_var in ("PORT", "GRADIO_SERVER_PORT"):
+         value = os.environ.get(env_var)
+         if value:
+             try:
+                 return int(value)
+             except ValueError:
+                 print(f"Ignoring invalid port value from {env_var}: {value}")
+     return 7860
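+
+ # Resolution order, illustrated with hypothetical environment values:
+ #     PORT=8080                          → 8080 (PORT wins over GRADIO_SERVER_PORT)
+ #     PORT=abc GRADIO_SERVER_PORT=9000   → 9000 (the invalid PORT is skipped)
+ #     neither set                        → the Gradio default, 7860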
+
+
+ def main():
+     print("Building Gradio interface...")
+     try:
+         demo = build_interface()
+         print("Interface built successfully")
+     except Exception as e:
+         print(f"Failed to build interface: {e}")
+         import traceback
+         traceback.print_exc()
+         return
+
+     print("Setting up queue...")
+     try:
+         demo.queue(max_size=QUEUE_MAX_SIZE)
+         print("Queue configured")
+     except Exception as e:
+         print(f"Failed to configure queue: {e}")
+
+     try:
+         port = resolve_server_port()
+         print(f"Launching Gradio app on port {port}")
+         demo.launch(server_name="0.0.0.0", server_port=port, show_error=True)
+     except OSError as exc:
+         print("Failed to launch on requested port:", exc)
+         try:
+             demo.launch(server_name="0.0.0.0", show_error=True)
+         except Exception as e:
+             print(f"Failed to launch completely: {e}")
+     except Exception as e:
+         print(f"Unexpected launch error: {e}")
+         import traceback
+         traceback.print_exc()
+
+
+ if __name__ == "__main__":
+     import sys
+
+     print("=" * 50)
+     print("PMU Fault Classification App Starting")
+     print(f"Python version: {sys.version}")
+     print(f"Working directory: {os.getcwd()}")
+     print(f"HUB_REPO: {HUB_REPO}")
+     print(f"Model available: {MODEL is not None}")
+     print(f"Scaler available: {SCALER is not None}")
+     print("=" * 50)
+     main()
.history/app_20251009231310.py ADDED
@@ -0,0 +1,2402 @@
+ """Gradio front-end for Fault_Classification_PMU_Data models.
+
+ The application loads a CNN-LSTM model (and accompanying scaler/metadata)
+ produced by ``fault_classification_pmu.py`` and exposes a streamlined
+ prediction interface optimised for Hugging Face Spaces deployment. It supports
+ raw PMU time-series CSV uploads as well as manual comma-separated feature
+ vectors.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import shutil
+
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1")
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
+ os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
+
+ import re
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+ import gradio as gr
+ import joblib
+ import numpy as np
+ import pandas as pd
+ import requests
+ from huggingface_hub import hf_hub_download
+ from tensorflow.keras.models import load_model
+
+ from fault_classification_pmu import (
+     DEFAULT_FEATURE_COLUMNS as TRAINING_DEFAULT_FEATURE_COLUMNS,
+     LABEL_GUESS_CANDIDATES as TRAINING_LABEL_GUESSES,
+     train_from_dataframe,
+ )
+
+ # --------------------------------------------------------------------------------------
+ # Configuration
+ # --------------------------------------------------------------------------------------
+ DEFAULT_FEATURE_COLUMNS: List[str] = list(TRAINING_DEFAULT_FEATURE_COLUMNS)
+ DEFAULT_SEQUENCE_LENGTH = 32
+ DEFAULT_STRIDE = 4
+
+ LOCAL_MODEL_FILE = os.environ.get("PMU_MODEL_FILE", "pmu_cnn_lstm_model.keras")
+ LOCAL_SCALER_FILE = os.environ.get("PMU_SCALER_FILE", "pmu_feature_scaler.pkl")
+ LOCAL_METADATA_FILE = os.environ.get("PMU_METADATA_FILE", "pmu_metadata.json")
+
+ MODEL_OUTPUT_DIR = Path(os.environ.get("PMU_MODEL_DIR", "model")).resolve()
+ MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+ HUB_REPO = os.environ.get("PMU_HUB_REPO", "")
+ HUB_MODEL_FILENAME = os.environ.get("PMU_HUB_MODEL_FILENAME", LOCAL_MODEL_FILE)
+ HUB_SCALER_FILENAME = os.environ.get("PMU_HUB_SCALER_FILENAME", LOCAL_SCALER_FILE)
+ HUB_METADATA_FILENAME = os.environ.get("PMU_HUB_METADATA_FILENAME", LOCAL_METADATA_FILE)
+
+ ENV_MODEL_PATH = "PMU_MODEL_PATH"
+ ENV_SCALER_PATH = "PMU_SCALER_PATH"
+ ENV_METADATA_PATH = "PMU_METADATA_PATH"
+
+ # --------------------------------------------------------------------------------------
+ # Utility functions for loading artifacts
+ # --------------------------------------------------------------------------------------
+
+
+ def download_from_hub(filename: str) -> Optional[Path]:
+     if not HUB_REPO or not filename:
+         return None
+     try:
+         print(f"Downloading {filename} from {HUB_REPO} ...")
+         path = hf_hub_download(repo_id=HUB_REPO, filename=filename)
+         print("Downloaded", path)
+         return Path(path)
+     except Exception as exc:  # pragma: no cover - logging convenience
+         print("Failed to download", filename, "from", HUB_REPO, ":", exc)
+         print("Continuing without pre-trained model...")
+         return None
+
+
+ def resolve_artifact(
+     local_name: str, env_var: str, hub_filename: str
+ ) -> Optional[Path]:
+     print(f"Resolving artifact: {local_name}, env: {env_var}, hub: {hub_filename}")
+     candidates = [Path(local_name)] if local_name else []
+     if local_name:
+         candidates.append(MODEL_OUTPUT_DIR / Path(local_name).name)
+     env_value = os.environ.get(env_var)
+     if env_value:
+         candidates.append(Path(env_value))
+
+     for candidate in candidates:
+         if candidate and candidate.exists():
+             print(f"Found local artifact: {candidate}")
+             return candidate
+
+     print("No local artifacts found, checking hub...")
+     # Only try to download if we have a hub repo configured
+     if HUB_REPO:
+         return download_from_hub(hub_filename)
+     else:
+         print("No HUB_REPO configured, skipping download")
+         return None
+
+
+ def load_metadata(path: Optional[Path]) -> Dict:
+     if path and path.exists():
+         try:
+             return json.loads(path.read_text())
+         except Exception as exc:  # pragma: no cover - metadata parsing errors
+             print("Failed to read metadata", path, exc)
+     return {}
+
+
+ def try_load_scaler(path: Optional[Path]):
+     if not path:
+         return None
+     try:
+         scaler = joblib.load(path)
+         print("Loaded scaler from", path)
+         return scaler
+     except Exception as exc:
+         print("Failed to load scaler", path, exc)
+         return None
+
+
+ # Initialize paths with error handling
+ print("Starting application initialization...")
+ try:
+     MODEL_PATH = resolve_artifact(LOCAL_MODEL_FILE, ENV_MODEL_PATH, HUB_MODEL_FILENAME)
+     print(f"Model path resolved: {MODEL_PATH}")
+ except Exception as e:
+     print(f"Model path resolution failed: {e}")
+     MODEL_PATH = None
+
+ try:
+     SCALER_PATH = resolve_artifact(
+         LOCAL_SCALER_FILE, ENV_SCALER_PATH, HUB_SCALER_FILENAME
+     )
+     print(f"Scaler path resolved: {SCALER_PATH}")
+ except Exception as e:
+     print(f"Scaler path resolution failed: {e}")
+     SCALER_PATH = None
+
+ try:
+     METADATA_PATH = resolve_artifact(
+         LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME
+     )
+     print(f"Metadata path resolved: {METADATA_PATH}")
+ except Exception as e:
+     print(f"Metadata path resolution failed: {e}")
+     METADATA_PATH = None
+
+ try:
+     METADATA = load_metadata(METADATA_PATH)
+     print(f"Metadata loaded: {len(METADATA)} entries")
+ except Exception as e:
+     print(f"Metadata loading failed: {e}")
+     METADATA = {}
+
+ # Queuing configuration
+ QUEUE_MAX_SIZE = 32
+ # Apply a small per-event concurrency limit to avoid relying on the deprecated
+ # ``concurrency_count`` parameter when enabling Gradio's request queue.
+ EVENT_CONCURRENCY_LIMIT = 2
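+
+ # A minimal sketch of where these two knobs land (hypothetical handler names;
+ # both calls use the current Gradio API):
+ #     demo.queue(max_size=QUEUE_MAX_SIZE)
+ #     run_button.click(fn, inputs=[...], outputs=[...],
+ #                      concurrency_limit=EVENT_CONCURRENCY_LIMIT)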
+
+
+ def try_load_model(path: Optional[Path], model_type: str, model_format: str):
+     if not path:
+         return None
+     try:
+         if model_type == "svm" or model_format == "joblib":
+             model = joblib.load(path)
+         else:
+             model = load_model(path)
+         print("Loaded model from", path)
+         return model
+     except Exception as exc:  # pragma: no cover - runtime diagnostics
+         print("Failed to load model", path, exc)
+         return None
+
+
+ FEATURE_COLUMNS: List[str] = list(DEFAULT_FEATURE_COLUMNS)
+ LABEL_CLASSES: List[str] = []
+ LABEL_COLUMN: str = "Fault"
+ SEQUENCE_LENGTH: int = DEFAULT_SEQUENCE_LENGTH
+ DEFAULT_WINDOW_STRIDE: int = DEFAULT_STRIDE
+ MODEL_TYPE: str = "cnn_lstm"
+ MODEL_FORMAT: str = "keras"
+
+
+ def _model_output_path(filename: str) -> str:
+     return str(MODEL_OUTPUT_DIR / Path(filename).name)
+
+
+ MODEL_FILENAME_BY_TYPE: Dict[str, str] = {
+     "cnn_lstm": Path(LOCAL_MODEL_FILE).name,
+     "tcn": "pmu_tcn_model.keras",
+     "svm": "pmu_svm_model.joblib",
+ }
+
+ REQUIRED_PMU_COLUMNS: Tuple[str, ...] = tuple(DEFAULT_FEATURE_COLUMNS)
+ TRAINING_UPLOAD_DIR = Path(
+     os.environ.get("PMU_TRAINING_UPLOAD_DIR", "training_uploads")
+ )
+ TRAINING_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
+
+ TRAINING_DATA_REPO = os.environ.get(
+     "PMU_TRAINING_DATA_REPO", "VincentCroft/ThesisModelData"
+ )
+ TRAINING_DATA_BRANCH = os.environ.get("PMU_TRAINING_DATA_BRANCH", "main")
+ TRAINING_DATA_DIR = Path(os.environ.get("PMU_TRAINING_DATA_DIR", "training_dataset"))
+ TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
+
+ GITHUB_CONTENT_CACHE: Dict[str, List[Dict[str, Any]]] = {}
+
+
+ APP_CSS = """
+ #available-files-section {
+     position: relative;
+     display: flex;
+     flex-direction: column;
+     gap: 0.75rem;
+     border-radius: 0.75rem;
+ }
+
+ #available-files-grid {
+     position: relative;
+     overflow: visible;
+ }
+
+ #available-files-grid .form {
+     position: relative;
+     min-height: 16rem;
+ }
+
+ #available-files-section:has(.gradio-loading) {
+     isolation: isolate;
+ }
+
+ #available-files-grid .wrap {
+     display: grid;
+     grid-template-columns: repeat(4, minmax(0, 1fr));
+     gap: 0.5rem;
+     max-height: 24rem;
+     min-height: 16rem;
+     overflow-y: auto;
+     padding-right: 0.25rem;
+ }
+
+ #available-files-grid .wrap > div {
+     min-width: 0;
+ }
+
+ #available-files-grid .wrap label {
+     margin: 0;
+     display: flex;
+     align-items: center;
+     padding: 0.45rem 0.65rem;
+     border-radius: 0.65rem;
+     background-color: rgba(255, 255, 255, 0.05);
+     border: 1px solid rgba(255, 255, 255, 0.08);
+     transition: background-color 0.2s ease, border-color 0.2s ease;
+     min-height: 2.5rem;
+ }
+
+ #available-files-grid .wrap label:hover {
+     background-color: rgba(90, 200, 250, 0.16);
+     border-color: rgba(90, 200, 250, 0.4);
+ }
+
+ #available-files-grid .wrap label span {
+     overflow: hidden;
+     text-overflow: ellipsis;
+     white-space: nowrap;
+ }
+
+ #available-files-grid .gradio-loading {
+     position: absolute;
+     inset: 0;
+     width: auto;
+     height: auto;
+     min-height: 100%;
+     display: flex;
+     align-items: center;
+     justify-content: center;
+     background: rgba(10, 14, 23, 0.72);
+     border-radius: 0.75rem;
+     z-index: 10;
+     padding: 1.5rem;
+     pointer-events: auto;
+ }
+
+ #available-files-grid .gradio-loading > * {
+     width: 100%;
+ }
+
+ #available-files-grid .gradio-loading progress,
+ #available-files-grid .gradio-loading .progress-bar,
+ #available-files-grid .gradio-loading .loading-progress,
+ #available-files-grid .gradio-loading [role="progressbar"],
+ #available-files-grid .gradio-loading .wrap,
+ #available-files-grid .gradio-loading .inner {
+     width: 100% !important;
+     max-width: none !important;
+ }
+
+ #available-files-grid .gradio-loading .status,
+ #available-files-grid .gradio-loading .message,
+ #available-files-grid .gradio-loading .label {
+     text-align: center;
+ }
+
+ #date-browser-row {
+     gap: 0.75rem;
+ }
+
+ #date-browser-row .date-browser-column {
+     flex: 1 1 0%;
+     min-width: 0;
+ }
+
+ #date-browser-row .date-browser-column > .gradio-dropdown,
+ #date-browser-row .date-browser-column > .gradio-button {
+     width: 100%;
+ }
+
+ #date-browser-row .date-browser-column > .gradio-dropdown > div {
+     width: 100%;
+ }
+
+ #date-browser-row .date-browser-column .gradio-button {
+     justify-content: center;
+ }
+
+ #training-files-summary textarea {
+     max-height: 12rem;
+     overflow-y: auto;
+ }
+
+ #download-selected-button {
+     width: 100%;
+     position: relative;
+     z-index: 0;
+ }
+
+ #download-selected-button .gradio-button {
+     width: 100%;
+     justify-content: center;
+ }
+
+ #artifact-download-row {
+     gap: 0.75rem;
+ }
+
+ #artifact-download-row .artifact-download-button {
+     flex: 1 1 0%;
+     min-width: 0;
+ }
+
+ #artifact-download-row .artifact-download-button .gradio-button {
+     width: 100%;
+     justify-content: center;
+ }
+ """
+
+
+ def _github_cache_key(path: str) -> str:
+     return path or "__root__"
+
+
+ def _github_api_url(path: str) -> str:
+     clean_path = path.strip("/")
+     base = f"https://api.github.com/repos/{TRAINING_DATA_REPO}/contents"
+     if clean_path:
+         return f"{base}/{clean_path}?ref={TRAINING_DATA_BRANCH}"
+     return f"{base}?ref={TRAINING_DATA_BRANCH}"
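+
+ # For example (hypothetical date path, default repo/branch from the config above):
+ #     _github_api_url("2024/01")
+ #     → "https://api.github.com/repos/VincentCroft/ThesisModelData/contents/2024/01?ref=main"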
+
+
+ def list_remote_directory(
+     path: str = "", *, force_refresh: bool = False
+ ) -> List[Dict[str, Any]]:
+     key = _github_cache_key(path)
+     if not force_refresh and key in GITHUB_CONTENT_CACHE:
+         return GITHUB_CONTENT_CACHE[key]
+
+     url = _github_api_url(path)
+     response = requests.get(url, timeout=30)
+     if response.status_code != 200:
+         raise RuntimeError(
+             f"GitHub API request failed for `{path or '.'}` (status {response.status_code})."
+         )
+
+     payload = response.json()
+     if not isinstance(payload, list):
+         raise RuntimeError(
+             "Unexpected GitHub API payload. Expected a directory listing."
+         )
+
+     GITHUB_CONTENT_CACHE[key] = payload
+     return payload
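+
+ # The GitHub contents API returns a list of objects shaped roughly like
+ #     {"name": "2024", "type": "dir", ...}
+ #     {"name": "measurements.csv", "type": "file", ...}
+ # (the names here are hypothetical); the helpers below only consume the
+ # "name" and "type" keys.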
+
+
+ def list_remote_years(force_refresh: bool = False) -> List[str]:
+     entries = list_remote_directory("", force_refresh=force_refresh)
+     years = [item["name"] for item in entries if item.get("type") == "dir"]
+     return sorted(years)
+
+
+ def list_remote_months(year: str, *, force_refresh: bool = False) -> List[str]:
+     if not year:
+         return []
+     entries = list_remote_directory(year, force_refresh=force_refresh)
+     months = [item["name"] for item in entries if item.get("type") == "dir"]
+     return sorted(months)
+
+
+ def list_remote_days(
+     year: str, month: str, *, force_refresh: bool = False
+ ) -> List[str]:
+     if not year or not month:
+         return []
+     entries = list_remote_directory(f"{year}/{month}", force_refresh=force_refresh)
+     days = [item["name"] for item in entries if item.get("type") == "dir"]
+     return sorted(days)
+
+
+ def list_remote_files(
+     year: str, month: str, day: str, *, force_refresh: bool = False
+ ) -> List[str]:
+     if not year or not month or not day:
+         return []
+     entries = list_remote_directory(
+         f"{year}/{month}/{day}", force_refresh=force_refresh
+     )
+     files = [item["name"] for item in entries if item.get("type") == "file"]
+     return sorted(files)
+
+
+ def download_repository_file(year: str, month: str, day: str, filename: str) -> Path:
+     if not filename:
+         raise ValueError("Filename cannot be empty when downloading repository data.")
+
+     relative_parts = [part for part in (year, month, day, filename) if part]
+     if len(relative_parts) < 4:
+         raise ValueError("Provide year, month, day, and filename to download a CSV.")
+
+     relative_path = "/".join(relative_parts)
+     raw_url = (
+         f"https://raw.githubusercontent.com/{TRAINING_DATA_REPO}/"
+         f"{TRAINING_DATA_BRANCH}/{relative_path}"
+     )
+
+     response = requests.get(raw_url, stream=True, timeout=120)
+     if response.status_code != 200:
+         raise RuntimeError(
+             f"Failed to download `{relative_path}` (status {response.status_code})."
+         )
+
+     target_dir = TRAINING_DATA_DIR.joinpath(year, month, day)
+     target_dir.mkdir(parents=True, exist_ok=True)
+     target_path = target_dir / filename
+
+     with open(target_path, "wb") as handle:
+         for chunk in response.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
+             if chunk:
+                 handle.write(chunk)
+
+     return target_path
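+
+ # Files are fetched from the raw endpoint, e.g. (hypothetical path segments):
+ #     https://raw.githubusercontent.com/VincentCroft/ThesisModelData/main/2024/01/15/example.csv
+ # and mirrored under TRAINING_DATA_DIR with the same year/month/day layout.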
+
+
+ def _normalise_header(name: str) -> str:
+     return str(name).strip().lower()
+
+
+ def guess_label_from_columns(
+     columns: Sequence[str], preferred: Optional[str] = None
+ ) -> Optional[str]:
+     if not columns:
+         return preferred
+
+     lookup = {_normalise_header(col): str(col) for col in columns}
+
+     if preferred:
+         preferred_stripped = preferred.strip()
+         for col in columns:
+             if str(col).strip() == preferred_stripped:
+                 return str(col)
+         preferred_norm = _normalise_header(preferred)
+         if preferred_norm in lookup:
+             return lookup[preferred_norm]
+
+     for guess in TRAINING_LABEL_GUESSES:
+         guess_norm = _normalise_header(guess)
+         if guess_norm in lookup:
+             return lookup[guess_norm]
+
+     for col in columns:
+         if _normalise_header(col).startswith("fault"):
+             return str(col)
+
+     return str(columns[0])
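+
+ # Resolution order, illustrated with hypothetical columns
+ # ["Timestamp", "FaultType", "V1"]:
+ #   1. an exact or case-insensitive match for `preferred`,
+ #   2. any entry from TRAINING_LABEL_GUESSES,
+ #   3. the first column whose name starts with "fault" → "FaultType",
+ #   4. otherwise the first column as a last resort → "Timestamp".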
+
+
+ def summarise_training_files(paths: Sequence[str], notes: Sequence[str]) -> str:
+     lines = [Path(path).name for path in paths]
+     lines.extend(notes)
+     return "\n".join(lines) if lines else "No training files available."
+
+
+ def read_training_status(status_file_path: str) -> str:
+     """Read the current training status from file."""
+     try:
+         if Path(status_file_path).exists():
+             with open(status_file_path, "r") as f:
+                 return f.read().strip()
+     except Exception:
+         pass
+     return "Training status unavailable"
+
+
+ def _persist_uploaded_file(file_obj) -> Optional[Path]:
+     if file_obj is None:
+         return None
+
+     if isinstance(file_obj, (str, Path)):
+         source = Path(file_obj)
+         original_name = source.name
+     else:
+         raw_path = getattr(file_obj, "name", "") or getattr(file_obj, "path", "")
+         if not raw_path:
+             # Guard on the raw string: Path("") resolves to "." and exists.
+             return None
+         source = Path(raw_path)
+         original_name = getattr(file_obj, "orig_name", source.name) or source.name
+     if not source.exists():
+         return None
+
+     original_name = Path(original_name).name or source.name
+
+     base_path = Path(original_name)
+     destination = TRAINING_UPLOAD_DIR / base_path.name
+     counter = 1
+     while destination.exists():
+         suffix = base_path.suffix or ".csv"
+         destination = TRAINING_UPLOAD_DIR / f"{base_path.stem}_{counter}{suffix}"
+         counter += 1
+
+     shutil.copy2(source, destination)
+     return destination
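+
+ # Name-collision handling, illustrated (hypothetical file names): a second
+ # upload of "data.csv" is persisted as "data_1.csv", a third as "data_2.csv",
+ # so earlier uploads are never overwritten.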
+
+
+ def prepare_training_paths(
+     paths: Sequence[str], current_label: str, cleanup_missing: bool = False
+ ):
+     valid_paths: List[str] = []
+     notes: List[str] = []
+     columns_map: Dict[str, str] = {}
+     for path in paths:
+         try:
+             df = load_measurement_csv(path)
+         except Exception as exc:  # pragma: no cover - user file diagnostics
+             notes.append(f"⚠️ Skipped {Path(path).name}: {exc}")
+             if cleanup_missing:
+                 try:
+                     Path(path).unlink(missing_ok=True)
+                 except Exception:
+                     pass
+             continue
+         valid_paths.append(str(path))
+         for col in df.columns:
+             columns_map[_normalise_header(col)] = str(col)
+
+     summary = summarise_training_files(valid_paths, notes)
+     preferred = current_label or LABEL_COLUMN
+     dropdown_choices = (
+         sorted(columns_map.values()) if columns_map else [preferred or LABEL_COLUMN]
+     )
+     guessed = guess_label_from_columns(dropdown_choices, preferred)
+     dropdown_value = guessed or preferred or LABEL_COLUMN
+
+     return (
+         valid_paths,
+         summary,
+         gr.update(choices=dropdown_choices, value=dropdown_value),
+     )
+
+
+ def append_training_files(new_files, existing_paths: Sequence[str], current_label: str):
+     if isinstance(existing_paths, (str, Path)):
+         paths: List[str] = [str(existing_paths)]
+     elif existing_paths is None:
+         paths = []
+     else:
+         paths = list(existing_paths)
+     if new_files:
+         for file in new_files:
+             persisted = _persist_uploaded_file(file)
+             if persisted is None:
+                 continue
+             path_str = str(persisted)
+             if path_str not in paths:
+                 paths.append(path_str)
+
+     return prepare_training_paths(paths, current_label, cleanup_missing=True)
604
+ def load_repository_training_files(current_label: str, force_refresh: bool = False):
605
+ if force_refresh:
606
+ # Clearing the cache is enough because downloads are now on-demand.
607
+ for cached in list(TRAINING_DATA_DIR.glob("*")):
608
+ # On refresh we keep previously downloaded files; no deletion required.
609
+ # The flag triggers downstream UI updates only.
610
+ break
611
+
612
+ csv_paths = sorted(
613
+ str(path) for path in TRAINING_DATA_DIR.rglob("*.csv") if path.is_file()
614
+ )
615
+ if not csv_paths:
616
+ message = (
617
+ "No local database CSVs are available yet. Use the database browser "
618
+ "below to download specific days before training."
619
+ )
620
+ default_label = current_label or LABEL_COLUMN or "Fault"
621
+ return (
622
+ [],
623
+ message,
624
+ gr.update(choices=[default_label], value=default_label),
625
+ message,
626
+ )
627
+
628
+ valid_paths, summary, label_update = prepare_training_paths(
629
+ csv_paths, current_label, cleanup_missing=False
630
+ )
631
+
632
+ info = (
633
+ f"Ready with {len(valid_paths)} CSV file(s) cached locally under "
634
+ f"the database cache `{TRAINING_DATA_DIR}`."
635
+ )
636
+
637
+ return valid_paths, summary, label_update, info
638
+
639
+
640
+ def refresh_remote_browser(force_refresh: bool = False):
641
+ if force_refresh:
642
+ GITHUB_CONTENT_CACHE.clear()
643
+ try:
644
+ years = list_remote_years(force_refresh=force_refresh)
645
+ if years:
646
+ message = "Select a year, month, and day to list available CSV files."
647
+ else:
648
+ message = (
649
+ "⚠️ No directories were found in the database root. Verify the upstream "
650
+ "structure."
651
+ )
652
+ return (
653
+ gr.update(choices=years, value=None),
654
+ gr.update(choices=[], value=None),
655
+ gr.update(choices=[], value=None),
656
+ gr.update(choices=[], value=[]),
657
+ message,
658
+ )
659
+ except Exception as exc:
660
+ return (
661
+ gr.update(choices=[], value=None),
662
+ gr.update(choices=[], value=None),
663
+ gr.update(choices=[], value=None),
664
+ gr.update(choices=[], value=[]),
665
+ f"⚠️ Failed to query database: {exc}",
666
+ )
667
+
668
+
669
+ def on_year_change(year: Optional[str]):
670
+ if not year:
671
+ return (
672
+ gr.update(choices=[], value=None),
673
+ gr.update(choices=[], value=None),
674
+ gr.update(choices=[], value=[]),
675
+ "Select a year to continue.",
676
+ )
677
+ try:
678
+ months = list_remote_months(year)
679
+ message = (
680
+ f"Year `{year}` selected. Choose a month to drill down."
681
+ if months
682
+ else f"⚠️ No months available under `{year}`."
683
+ )
684
+ return (
685
+ gr.update(choices=months, value=None),
686
+ gr.update(choices=[], value=None),
687
+ gr.update(choices=[], value=[]),
688
+ message,
689
+ )
690
+ except Exception as exc:
691
+ return (
692
+ gr.update(choices=[], value=None),
693
+ gr.update(choices=[], value=None),
694
+ gr.update(choices=[], value=[]),
695
+ f"⚠️ Failed to list months: {exc}",
696
+ )
697
+
698
+
699
+ def on_month_change(year: Optional[str], month: Optional[str]):
700
+ if not year or not month:
701
+ return (
702
+ gr.update(choices=[], value=None),
703
+ gr.update(choices=[], value=[]),
704
+ "Select a month to continue.",
705
+ )
706
+ try:
707
+ days = list_remote_days(year, month)
708
+ message = (
709
+ f"Month `{year}/{month}` ready. Pick a day to view files."
710
+ if days
711
+ else f"⚠️ No day folders found under `{year}/{month}`."
712
+ )
713
+ return (
714
+ gr.update(choices=days, value=None),
715
+ gr.update(choices=[], value=[]),
716
+ message,
717
+ )
718
+ except Exception as exc:
719
+ return (
720
+ gr.update(choices=[], value=None),
721
+ gr.update(choices=[], value=[]),
722
+ f"⚠️ Failed to list days: {exc}",
723
+ )
724
+
725
+
726
+ def on_day_change(year: Optional[str], month: Optional[str], day: Optional[str]):
727
+ if not year or not month or not day:
728
+ return (
729
+ gr.update(choices=[], value=[]),
730
+ "Select a day to load file names.",
731
+ )
732
+ try:
733
+ files = list_remote_files(year, month, day)
734
+ message = (
735
+ f"{len(files)} file(s) available for `{year}/{month}/{day}`."
736
+ if files
737
+ else f"⚠️ No CSV files found under `{year}/{month}/{day}`."
738
+ )
739
+ return (
740
+ gr.update(choices=files, value=[]),
741
+ message,
742
+ )
743
+ except Exception as exc:
744
+ return (
745
+ gr.update(choices=[], value=[]),
746
+ f"⚠️ Failed to list files: {exc}",
747
+ )
748
+
749
+
750
+ def download_selected_files(
751
+ year: Optional[str],
752
+ month: Optional[str],
753
+ day: Optional[str],
754
+ filenames: Sequence[str],
755
+ current_label: str,
756
+ ):
757
+ if not filenames:
758
+ message = "Select at least one CSV before downloading."
759
+ local = load_repository_training_files(current_label)
760
+ return (*local, gr.update(), message)
761
+
762
+ success: List[str] = []
763
+ notes: List[str] = []
764
+ for filename in filenames:
765
+ try:
766
+ path = download_repository_file(
767
+ year or "", month or "", day or "", filename
768
+ )
769
+ success.append(str(path))
770
+ except Exception as exc:
771
+ notes.append(f"⚠️ {filename}: {exc}")
772
+
773
+ local = load_repository_training_files(current_label)
774
+
775
+ message_lines = []
776
+ if success:
777
+ message_lines.append(
778
+ f"Downloaded {len(success)} file(s) to the database cache `{TRAINING_DATA_DIR}`."
779
+ )
780
+ if notes:
781
+ message_lines.extend(notes)
782
+ if not message_lines:
783
+ message_lines.append("No files were downloaded.")
784
+
785
+ return (*local, gr.update(value=[]), "\n".join(message_lines))
786
+
787
+
788
+ def download_day_bundle(
789
+ year: Optional[str],
790
+ month: Optional[str],
791
+ day: Optional[str],
792
+ current_label: str,
793
+ ):
794
+ if not (year and month and day):
795
+ local = load_repository_training_files(current_label)
796
+ return (
797
+ *local,
798
+ gr.update(),
799
+ "Select a year, month, and day before downloading an entire day.",
800
+ )
801
+
802
+ try:
803
+ files = list_remote_files(year, month, day)
804
+ except Exception as exc:
805
+ local = load_repository_training_files(current_label)
806
+ return (
807
+ *local,
808
+ gr.update(),
809
+ f"⚠️ Failed to list CSVs for `{year}/{month}/{day}`: {exc}",
810
+ )
811
+
812
+ if not files:
813
+ local = load_repository_training_files(current_label)
814
+ return (
815
+ *local,
816
+ gr.update(),
817
+ f"No CSV files were found for `{year}/{month}/{day}`.",
818
+ )
819
+
820
+ result = list(download_selected_files(year, month, day, files, current_label))
821
+ result[-1] = (
822
+ f"Downloaded all {len(files)} CSV file(s) for `{year}/{month}/{day}`.\n"
823
+ f"{result[-1]}"
824
+ )
825
+ return tuple(result)
826
+
827
+
828
+ def download_month_bundle(
829
+ year: Optional[str], month: Optional[str], current_label: str
830
+ ):
831
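+ """Download every CSV for a whole month by iterating over its day folders."""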
+ if not (year and month):
832
+ local = load_repository_training_files(current_label)
833
+ return (
834
+ *local,
835
+ gr.update(),
836
+ "Select a year and month before downloading an entire month.",
837
+ )
838
+
839
+ try:
840
+ days = list_remote_days(year, month)
841
+ except Exception as exc:
842
+ local = load_repository_training_files(current_label)
843
+ return (
844
+ *local,
845
+ gr.update(),
846
+ f"⚠️ Failed to enumerate days for `{year}/{month}`: {exc}",
847
+ )
848
+
849
+ if not days:
850
+ local = load_repository_training_files(current_label)
851
+ return (
852
+ *local,
853
+ gr.update(),
854
+ f"No day folders were found for `{year}/{month}`.",
855
+ )
856
+
857
+ downloaded = 0
858
+ notes: List[str] = []
859
+ for day in days:
860
+ try:
861
+ files = list_remote_files(year, month, day)
862
+ except Exception as exc:
863
+ notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
864
+ continue
865
+ if not files:
866
+ notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
867
+ continue
868
+ for filename in files:
869
+ try:
870
+ download_repository_file(year, month, day, filename)
871
+ downloaded += 1
872
+ except Exception as exc:
873
+ notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
874
+
875
+ local = load_repository_training_files(current_label)
876
+ message_lines = []
877
+ if downloaded:
878
+ message_lines.append(
879
+ f"Downloaded {downloaded} CSV file(s) for `{year}/{month}` into the "
880
+ f"database cache `{TRAINING_DATA_DIR}`."
881
+ )
882
+ message_lines.extend(notes)
883
+ if not message_lines:
884
+ message_lines.append("No files were downloaded.")
885
+
886
+ return (*local, gr.update(value=[]), "\n".join(message_lines))
887
+
888
+
889
+ def download_year_bundle(year: Optional[str], current_label: str):
890
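+ """Download every CSV for a whole year, walking its month and day folders."""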
+ if not year:
891
+ local = load_repository_training_files(current_label)
892
+ return (
893
+ *local,
894
+ gr.update(),
895
+ "Select a year before downloading an entire year of CSVs.",
896
+ )
897
+
898
+ try:
899
+ months = list_remote_months(year)
900
+ except Exception as exc:
901
+ local = load_repository_training_files(current_label)
902
+ return (
903
+ *local,
904
+ gr.update(),
905
+ f"⚠️ Failed to enumerate months for `{year}`: {exc}",
906
+ )
907
+
908
+ if not months:
909
+ local = load_repository_training_files(current_label)
910
+ return (
911
+ *local,
912
+ gr.update(),
913
+ f"No month folders were found for `{year}`.",
914
+ )
915
+
916
+ downloaded = 0
917
+ notes: List[str] = []
918
+ for month in months:
919
+ try:
920
+ days = list_remote_days(year, month)
921
+ except Exception as exc:
922
+ notes.append(f"⚠️ Failed to list `{year}/{month}`: {exc}")
923
+ continue
924
+ if not days:
925
+ notes.append(f"⚠️ No day folders in `{year}/{month}`.")
926
+ continue
927
+ for day in days:
928
+ try:
929
+ files = list_remote_files(year, month, day)
930
+ except Exception as exc:
931
+ notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
932
+ continue
933
+ if not files:
934
+ notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
935
+ continue
936
+ for filename in files:
937
+ try:
938
+ download_repository_file(year, month, day, filename)
939
+ downloaded += 1
940
+ except Exception as exc:
941
+ notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
942
+
943
+ local = load_repository_training_files(current_label)
944
+ message_lines = []
945
+ if downloaded:
946
+ message_lines.append(
947
+ f"Downloaded {downloaded} CSV file(s) for `{year}` into the "
948
+ f"database cache `{TRAINING_DATA_DIR}`."
949
+ )
950
+ message_lines.extend(notes)
951
+ if not message_lines:
952
+ message_lines.append("No files were downloaded.")
953
+
954
+ return (*local, gr.update(value=[]), "\n".join(message_lines))
955
+
956
+
957
+ def clear_downloaded_cache(current_label: str):
958
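+ """Delete the downloaded-CSV cache, recreate the directory, and refresh both views."""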
+ status_message = ""
959
+ try:
960
+ if TRAINING_DATA_DIR.exists():
961
+ shutil.rmtree(TRAINING_DATA_DIR)
962
+ TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
963
+ status_message = (
964
+ f"Cleared all downloaded CSVs from database cache `{TRAINING_DATA_DIR}`."
965
+ )
966
+ except Exception as exc:
967
+ status_message = f"⚠️ Failed to clear database cache: {exc}"
968
+
969
+ local = load_repository_training_files(current_label, force_refresh=True)
970
+ remote = list(refresh_remote_browser(force_refresh=False))
971
+ if status_message:
972
+ previous = remote[-1]
973
+ if isinstance(previous, str) and previous:
974
+ remote[-1] = f"{status_message}\n{previous}"
975
+ else:
976
+ remote[-1] = status_message
977
+
978
+ return (*local, *remote)
979
+
980
+
981
+ def normalise_output_directory(directory: Optional[str]) -> Path:
982
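+ """Expand and absolutise the output directory, defaulting to MODEL_OUTPUT_DIR."""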
+ base = Path(directory or MODEL_OUTPUT_DIR)
983
+ base = base.expanduser()
984
+ if not base.is_absolute():
985
+ base = (Path.cwd() / base).resolve()
986
+ return base
987
+
988
+
989
+ def resolve_output_path(
990
+ directory: Optional[Union[Path, str]], filename: Optional[str], fallback: str
991
+ ) -> Path:
992
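+ """Resolve ``filename`` against ``directory``, using ``fallback`` when no name is given."""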
+ if isinstance(directory, Path):
993
+ base = directory
994
+ else:
995
+ base = normalise_output_directory(directory)
996
+ # Path("") normalises to ".", so test the raw filename before building a Path;
+ # otherwise the fallback name would never be used.
+ if filename and filename.strip():
+ candidate = Path(filename).expanduser()
+ if candidate.is_absolute():
+ return candidate
+ return (base / candidate).resolve()
+ return (base / fallback).resolve()
1002
+
1003
+
1004
+ ARTIFACT_FILE_EXTENSIONS: Tuple[str, ...] = (
1005
+ ".keras",
1006
+ ".h5",
1007
+ ".joblib",
1008
+ ".pkl",
1009
+ ".json",
1010
+ ".onnx",
1011
+ ".zip",
1012
+ ".txt",
1013
+ )
1014
+
1015
+
1016
+ def gather_directory_choices(current: Optional[str]) -> Tuple[List[str], str]:
1017
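+ """Offer the current output directory and its sibling directories as dropdown choices."""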
+ base = normalise_output_directory(current or str(MODEL_OUTPUT_DIR))
1018
+ candidates = {str(base)}
1019
+ try:
1020
+ for candidate in base.parent.iterdir():
1021
+ if candidate.is_dir():
1022
+ candidates.add(str(candidate.resolve()))
1023
+ except Exception:
1024
+ pass
1025
+ return sorted(candidates), str(base)
1026
+
1027
+
1028
+ def gather_artifact_choices(
1029
+ directory: Optional[str], selection: Optional[str] = None
1030
+ ) -> Tuple[List[Tuple[str, str]], Optional[str]]:
1031
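+ """List saved artifact files in ``directory`` and pick a default selection."""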
+ base = normalise_output_directory(directory)
1032
+ choices: List[Tuple[str, str]] = []
1033
+ selected_value: Optional[str] = None
1034
+ if base.exists():
1035
+ try:
1036
+ artifacts = sorted(
1037
+ [
1038
+ path
1039
+ for path in base.iterdir()
1040
+ if path.is_file()
1041
+ and (
1042
+ not ARTIFACT_FILE_EXTENSIONS
1043
+ or path.suffix.lower() in ARTIFACT_FILE_EXTENSIONS
1044
+ )
1045
+ ],
1046
+ key=lambda path: path.name.lower(),
1047
+ )
1048
+ choices = [(artifact.name, str(artifact)) for artifact in artifacts]
1049
+ except Exception:
1050
+ choices = []
1051
+
1052
+ if selection and any(value == selection for _, value in choices):
1053
+ selected_value = selection
1054
+ elif choices:
1055
+ selected_value = choices[0][1]
1056
+
1057
+ return choices, selected_value
1058
+
1059
+
1060
+ def download_button_state(path: Optional[Union[str, Path]]):
1061
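+ """Show the download button only when ``path`` points at an existing file."""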
+ if not path:
1062
+ return gr.update(value=None, visible=False)
1063
+ candidate = Path(path)
1064
+ if candidate.exists():
1065
+ return gr.update(value=str(candidate), visible=True)
1066
+ return gr.update(value=None, visible=False)
1067
+
1068
+
1069
+ def clear_training_files():
1070
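+ """Remove cached uploads and reset the training-file widgets to their defaults."""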
+ default_label = LABEL_COLUMN or "Fault"
1071
+ for cached_file in TRAINING_UPLOAD_DIR.glob("*"):
1072
+ try:
1073
+ if cached_file.is_file():
1074
+ cached_file.unlink(missing_ok=True)
1075
+ except Exception:
1076
+ pass
1077
+ return (
1078
+ [],
1079
+ "No training files selected.",
1080
+ gr.update(choices=[default_label], value=default_label),
1081
+ gr.update(value=None),
1082
+ )
1083
+
1084
+
1085
+ PROJECT_OVERVIEW_MD = """
1086
+ ## Project Overview
1087
+
1088
+ This project focuses on classifying faults in electrical transmission lines and
1089
+ grid-connected photovoltaic (PV) systems by combining ensemble learning
1090
+ techniques with deep neural architectures.
1091
+
1092
+ ## Datasets
1093
+
1094
+ ### Transmission Line Fault Dataset
1095
+ - 134,406 samples collected from Phasor Measurement Units (PMUs)
1096
+ - 14 monitored channels covering currents, voltages, magnitudes, frequency, and phase angles
1097
+ - Labels span symmetrical and asymmetrical faults: NF, L-G, LL, LL-G, LLL, and LLL-G
1098
+ - Time span: 0 to 5.7 seconds with high-frequency sampling
1099
+
1100
+ ### Grid-Connected PV System Fault Dataset
1101
+ - 2,163,480 samples from 16 experimental scenarios
1102
+ - 14 features including PV array measurements (Ipv, Vpv, Vdc), three-phase currents/voltages, aggregate magnitudes (Iabc, Vabc), and frequency indicators (If, Vf)
1103
+ - Captures array, inverter, grid anomaly, feedback sensor, and MPPT controller faults at 9.9989 μs sampling intervals
1104
+
1105
+ ## Data Format Quick Reference
1106
+
1107
+ Each measurement file may be comma or tab separated and typically exposes the
1108
+ following ordered columns:
1109
+
1110
+ 1. `Timestamp`
1111
+ 2. `[325] UPMU_SUB22:FREQ` – system frequency (Hz)
1112
+ 3. `[326] UPMU_SUB22:DFDT` – frequency rate-of-change
1113
+ 4. `[327] UPMU_SUB22:FLAG` – PMU status flag
1114
+ 5. `[328] UPMU_SUB22-L1:MAG` – phase A voltage magnitude
1115
+ 6. `[329] UPMU_SUB22-L1:ANG` – phase A voltage angle
1116
+ 7. `[330] UPMU_SUB22-L2:MAG` – phase B voltage magnitude
1117
+ 8. `[331] UPMU_SUB22-L2:ANG` – phase B voltage angle
1118
+ 9. `[332] UPMU_SUB22-L3:MAG` – phase C voltage magnitude
1119
+ 10. `[333] UPMU_SUB22-L3:ANG` – phase C voltage angle
1120
+ 11. `[334] UPMU_SUB22-C1:MAG` – phase A current magnitude
1121
+ 12. `[335] UPMU_SUB22-C1:ANG` – phase A current angle
1122
+ 13. `[336] UPMU_SUB22-C2:MAG` – phase B current magnitude
1123
+ 14. `[337] UPMU_SUB22-C2:ANG` – phase B current angle
1124
+ 15. `[338] UPMU_SUB22-C3:MAG` – phase C current magnitude
1125
+ 16. `[339] UPMU_SUB22-C3:ANG` – phase C current angle
1126
+
1127
+ The training tab automatically downloads the latest CSV exports from the
1128
+ `VincentCroft/ThesisModelData` repository and concatenates them before building
1129
+ sliding windows.
1130
+
1131
+ ## Models Developed
1132
+
1133
+ 1. **Support Vector Machine (SVM)** – provides the classical machine learning baseline with balanced accuracy across both datasets (85% PMU / 83% PV).
1134
+ 2. **CNN-LSTM** – couples convolutional feature extraction with temporal memory, achieving 92% PMU / 89% PV accuracy.
1135
+ 3. **Temporal Convolutional Network (TCN)** – leverages dilated convolutions for long-range context and delivers the best trade-off between accuracy and training time (94% PMU / 91% PV).
1136
+
1137
+ ## Results Summary
1138
+
1139
+ - **Transmission Line Fault Classification**: SVM 85%, CNN-LSTM 92%, TCN 94%
1140
+ - **PV System Fault Classification**: SVM 83%, CNN-LSTM 89%, TCN 91%
1141
+
1142
+ Use the **Inference** tab to score new PMU/PV windows and the **Training** tab to
1143
+ fine-tune or retrain any of the supported models directly within Hugging Face
1144
+ Spaces. The logs panel will surface TensorBoard archives whenever deep-learning
1145
+ models are trained.
1146
+ """
1147
+
1148
+
1149
+ def load_measurement_csv(path: str) -> pd.DataFrame:
1150
+ """Read a PMU/PV measurement file with flexible separators and column mapping."""
1151
+
1152
+ try:
1153
+ df = pd.read_csv(path, sep=None, engine="python", encoding="utf-8-sig")
1154
+ except Exception:
1155
+ df = None
1156
+ for separator in ("\t", ",", ";"):
1157
+ try:
1158
+ df = pd.read_csv(
1159
+ path, sep=separator, engine="python", encoding="utf-8-sig"
1160
+ )
1161
+ break
1162
+ except Exception:
1163
+ df = None
1164
+ if df is None:
1165
+ raise
1166
+
1167
+ # Clean column names
1168
+ df.columns = [str(col).strip() for col in df.columns]
1169
+
1170
+ print(f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns")
1171
+ print(f"Columns: {list(df.columns)}")
1172
+ print(f"Data shape: {df.shape}")
1173
+
1174
+ # Check if we have enough data for training
1175
+ if len(df) < 100:
1176
+ print(
1177
+ f"Warning: Only {len(df)} rows of data. Recommend at least 1000 rows for effective training."
1178
+ )
1179
+
1180
+ # Check for label column
1181
+ has_label = any(
1182
+ col.lower() in ["fault", "label", "class", "target"] for col in df.columns
1183
+ )
1184
+ if not has_label:
1185
+ print(
1186
+ "Warning: No label column found. Adding dummy 'Fault' column with value 'Normal' for all samples."
1187
+ )
1188
+ df["Fault"] = "Normal" # Add dummy label for training
1189
+
1190
+ # Create column mapping - map similar column names to expected format
1191
+ column_mapping = {}
1192
+ expected_cols = list(REQUIRED_PMU_COLUMNS)
1193
+
1194
+ # If we have at least the right number of numeric columns after Timestamp, use positional mapping
1195
+ if "Timestamp" in df.columns:
1196
+ numeric_cols = [col for col in df.columns if col != "Timestamp"]
1197
+ if len(numeric_cols) >= len(expected_cols):
1198
+ # Map by position (after Timestamp)
1199
+ for i, expected_col in enumerate(expected_cols):
1200
+ if i < len(numeric_cols):
1201
+ column_mapping[numeric_cols[i]] = expected_col
1202
+
1203
+ # Rename columns to match expected format
1204
+ df = df.rename(columns=column_mapping)
1205
+
1206
+ # Check if we have the required columns after mapping
1207
+ missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]
1208
+ if missing:
1209
+ # If still missing, try a more flexible approach
1210
+ available_numeric = df.select_dtypes(include=[np.number]).columns.tolist()
1211
+ if len(available_numeric) >= len(expected_cols):
1212
+ # Use the first N numeric columns
1213
+ for i, expected_col in enumerate(expected_cols):
1214
+ if i < len(available_numeric):
1215
+ if available_numeric[i] not in df.columns:
1216
+ continue
1217
+ df = df.rename(columns={available_numeric[i]: expected_col})
1218
+
1219
+ # Recheck missing columns
1220
+ missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]
1221
+
1222
+ if missing:
1223
+ missing_str = ", ".join(missing)
1224
+ available_str = ", ".join(df.columns.tolist())
1225
+ raise ValueError(
1226
+ f"Missing required PMU feature columns: {missing_str}. "
1227
+ f"Available columns: {available_str}. "
1228
+ "Please ensure your CSV has the correct format with Timestamp followed by PMU measurements."
1229
+ )
1230
+
1231
+ return df
1232
+
1233
+
1234
+ def apply_metadata(metadata: Dict[str, Any]) -> None:
1235
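+ """Copy metadata values into the module-level configuration globals."""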
+ global FEATURE_COLUMNS, LABEL_CLASSES, LABEL_COLUMN, SEQUENCE_LENGTH, DEFAULT_WINDOW_STRIDE, MODEL_TYPE, MODEL_FORMAT
1236
+ FEATURE_COLUMNS = [
1237
+ str(col) for col in metadata.get("feature_columns", DEFAULT_FEATURE_COLUMNS)
1238
+ ]
1239
+ LABEL_CLASSES = [str(label) for label in metadata.get("label_classes", [])]
1240
+ LABEL_COLUMN = str(metadata.get("label_column", "Fault"))
1241
+ SEQUENCE_LENGTH = int(metadata.get("sequence_length", DEFAULT_SEQUENCE_LENGTH))
1242
+ DEFAULT_WINDOW_STRIDE = int(metadata.get("stride", DEFAULT_STRIDE))
1243
+ MODEL_TYPE = str(metadata.get("model_type", "cnn_lstm")).lower()
1244
+ MODEL_FORMAT = str(
1245
+ metadata.get("model_format", "joblib" if MODEL_TYPE == "svm" else "keras")
1246
+ ).lower()
1247
+
1248
+
1249
+ apply_metadata(METADATA)
1250
+
1251
+
1252
+ def sync_label_classes_from_model(model: Optional[object]) -> None:
1253
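+ """Derive label class names from the loaded model when metadata lacks them."""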
+ global LABEL_CLASSES
1254
+ if model is None:
1255
+ return
1256
+ if hasattr(model, "classes_"):
1257
+ LABEL_CLASSES = [str(label) for label in getattr(model, "classes_")]
1258
+ elif not LABEL_CLASSES and hasattr(model, "output_shape"):
1259
+ LABEL_CLASSES = [str(i) for i in range(int(model.output_shape[-1]))]
1260
+
1261
+
1262
+ # Load model and scaler with error handling
1263
+ print("Loading model and scaler...")
1264
+ try:
1265
+ MODEL = try_load_model(MODEL_PATH, MODEL_TYPE, MODEL_FORMAT)
1266
+ print(f"Model loaded: {MODEL is not None}")
1267
+ except Exception as e:
1268
+ print(f"Model loading failed: {e}")
1269
+ MODEL = None
1270
+
1271
+ try:
1272
+ SCALER = try_load_scaler(SCALER_PATH)
1273
+ print(f"Scaler loaded: {SCALER is not None}")
1274
+ except Exception as e:
1275
+ print(f"Scaler loading failed: {e}")
1276
+ SCALER = None
1277
+
1278
+ try:
1279
+ sync_label_classes_from_model(MODEL)
1280
+ print("Label classes synchronized")
1281
+ except Exception as e:
1282
+ print(f"Label sync failed: {e}")
1283
+
1284
+ print("Application initialization completed.")
1285
+ print(
1286
+ f"Ready to start Gradio interface. Model available: {MODEL is not None}, Scaler available: {SCALER is not None}"
1287
+ )
1288
+
1289
+
1290
+ def refresh_artifacts(model_path: Path, scaler_path: Path, metadata_path: Path) -> None:
1291
+ global MODEL_PATH, SCALER_PATH, METADATA_PATH, MODEL, SCALER, METADATA
1292
+ MODEL_PATH = model_path
1293
+ SCALER_PATH = scaler_path
1294
+ METADATA_PATH = metadata_path
1295
+ METADATA = load_metadata(metadata_path)
1296
+ apply_metadata(METADATA)
1297
+ MODEL = try_load_model(model_path, MODEL_TYPE, MODEL_FORMAT)
1298
+ SCALER = try_load_scaler(scaler_path)
1299
+ sync_label_classes_from_model(MODEL)
1300
+
1301
+
1302
+ # --------------------------------------------------------------------------------------
1303
+ # Pre-processing helpers
1304
+ # --------------------------------------------------------------------------------------
1305
+
1306
+
1307
+ def ensure_ready():
1308
+ if MODEL is None or SCALER is None:
1309
+ raise RuntimeError(
1310
+ "The model and feature scaler are not available. Upload the trained model "
1311
+ "(for example `pmu_cnn_lstm_model.keras`, `pmu_tcn_model.keras`, or `pmu_svm_model.joblib`), "
1312
+ "the feature scaler (`pmu_feature_scaler.pkl`), and the metadata JSON (`pmu_metadata.json`) to the Space root "
1313
+ "or configure the Hugging Face Hub environment variables so the artifacts can be downloaded "
1314
+ "automatically."
1315
+ )
1316
+
1317
+
1318
+ def parse_text_features(text: str) -> np.ndarray:
1319
+ cleaned = re.sub(r"[;\n\t]+", ",", text.strip())
1320
+ # np.fromstring(..., sep=",") is deprecated in NumPy; parse the values explicitly.
+ arr = np.array([float(v) for v in cleaned.split(",") if v.strip()], dtype=np.float32)
1321
+ if arr.size == 0:
1322
+ raise ValueError(
1323
+ "No feature values were parsed. Please enter comma-separated numbers."
1324
+ )
1325
+ return arr.astype(np.float32)
1326
+
1327
+
1328
+ def apply_scaler(sequences: np.ndarray) -> np.ndarray:
1329
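+ """Scale every timestep with the fitted scaler while preserving the window shape."""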
+ if SCALER is None:
1330
+ return sequences
1331
+ shape = sequences.shape
1332
+ flattened = sequences.reshape(-1, shape[-1])
1333
+ scaled = SCALER.transform(flattened)
1334
+ return scaled.reshape(shape)
1335
+
1336
+
1337
+ def make_sliding_windows(
1338
+ data: np.ndarray, sequence_length: int, stride: int
1339
+ ) -> np.ndarray:
1340
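+ """Slice ``data`` into overlapping windows of ``sequence_length`` rows, advancing by ``stride``."""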
+ if data.shape[0] < sequence_length:
1341
+ raise ValueError(
1342
+ f"The dataset contains {data.shape[0]} rows which is less than the requested sequence "
1343
+ f"length {sequence_length}. Provide more samples or reduce the sequence length."
1344
+ )
1345
+ windows = [
1346
+ data[start : start + sequence_length]
1347
+ for start in range(0, data.shape[0] - sequence_length + 1, stride)
1348
+ ]
1349
+ return np.stack(windows)
1350
+
1351
+
1352
+ def dataframe_to_sequences(
1353
+ df: pd.DataFrame,
1354
+ *,
1355
+ sequence_length: int,
1356
+ stride: int,
1357
+ feature_columns: Sequence[str],
1358
+ drop_label: bool = True,
1359
+ ) -> np.ndarray:
1360
+ work_df = df.copy()
1361
+ if drop_label and LABEL_COLUMN in work_df.columns:
1362
+ work_df = work_df.drop(columns=[LABEL_COLUMN])
1363
+ if "Timestamp" in work_df.columns:
1364
+ work_df = work_df.sort_values("Timestamp")
1365
+
1366
+ available_cols = [c for c in feature_columns if c in work_df.columns]
1367
+ n_features = len(feature_columns)
1368
+ if available_cols and len(available_cols) == n_features:
1369
+ array = work_df[available_cols].astype(np.float32).to_numpy()
1370
+ return make_sliding_windows(array, sequence_length, stride)
1371
+
1372
+ numeric_df = work_df.select_dtypes(include=[np.number])
1373
+ array = numeric_df.astype(np.float32).to_numpy()
1374
+ if array.shape[1] == n_features * sequence_length:
1375
+ return array.reshape(array.shape[0], sequence_length, n_features)
1376
+ if sequence_length == 1 and array.shape[1] == n_features:
1377
+ return array.reshape(array.shape[0], 1, n_features)
1378
+ raise ValueError(
1379
+ "CSV columns do not match the expected feature layout. Include the full PMU feature set "
1380
+ "or provide pre-shaped sliding window data."
1381
+ )
1382
+
1383
+
1384
+ def label_name(index: int) -> str:
1385
+ if 0 <= index < len(LABEL_CLASSES):
1386
+ return str(LABEL_CLASSES[index])
1387
+ return f"class_{index}"
1388
+
1389
+
1390
+ def format_predictions(probabilities: np.ndarray) -> pd.DataFrame:
1391
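+ """Summarise per-window probabilities into a table of top-1 label, confidence, and top-3 ranking."""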
+ rows: List[Dict[str, object]] = []
1392
+ order = np.argsort(probabilities, axis=1)[:, ::-1]
1393
+ for idx, (prob_row, ranking) in enumerate(zip(probabilities, order)):
1394
+ top_idx = int(ranking[0])
1395
+ top_label = label_name(top_idx)
1396
+ top_conf = float(prob_row[top_idx])
1397
+ top3 = [f"{label_name(i)} ({prob_row[i]*100:.2f}%)" for i in ranking[:3]]
1398
+ rows.append(
1399
+ {
1400
+ "window": idx,
1401
+ "predicted_label": top_label,
1402
+ "confidence": round(top_conf, 4),
1403
+ "top3": " | ".join(top3),
1404
+ }
1405
+ )
1406
+ return pd.DataFrame(rows)
1407
+
1408
+
1409
+ def probabilities_to_json(probabilities: np.ndarray) -> List[Dict[str, object]]:
1410
+ payload: List[Dict[str, object]] = []
1411
+ for idx, prob_row in enumerate(probabilities):
1412
+ payload.append(
1413
+ {
1414
+ "window": int(idx),
1415
+ "probabilities": {
1416
+ label_name(i): float(prob_row[i]) for i in range(prob_row.shape[0])
1417
+ },
1418
+ }
1419
+ )
1420
+ return payload
1421
+
1422
+
1423
+ def predict_sequences(
1424
+ sequences: np.ndarray,
1425
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
1426
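+ """Scale the windows and score them with the loaded SVM or Keras model."""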
+ ensure_ready()
1427
+ sequences = apply_scaler(sequences.astype(np.float32))
1428
+ if MODEL_TYPE == "svm":
1429
+ flattened = sequences.reshape(sequences.shape[0], -1)
1430
+ if hasattr(MODEL, "predict_proba"):
1431
+ probs = MODEL.predict_proba(flattened)
1432
+ else:
1433
+ raise RuntimeError(
1434
+ "Loaded SVM model does not expose predict_proba. Retrain with probability=True."
1435
+ )
1436
+ else:
1437
+ probs = MODEL.predict(sequences, verbose=0)
1438
+ table = format_predictions(probs)
1439
+ json_probs = probabilities_to_json(probs)
1440
+ architecture = MODEL_TYPE.replace("_", "-").upper()
1441
+ status = f"Generated {len(sequences)} windows. {architecture} model output dimension: {probs.shape[1]}."
1442
+ return status, table, json_probs
1443
+
1444
+
1445
+ def predict_from_text(
1446
+ text: str, sequence_length: int
1447
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
1448
+ arr = parse_text_features(text)
1449
+ n_features = len(FEATURE_COLUMNS)
1450
+ if arr.size % n_features != 0:
1451
+ raise ValueError(
1452
+ f"The number of values ({arr.size}) is not a multiple of the feature dimension "
1453
+ f"({n_features}). Provide values in groups of {n_features}."
1454
+ )
1455
+ timesteps = arr.size // n_features
1456
+ if timesteps != sequence_length:
1457
+ raise ValueError(
1458
+ f"Detected {timesteps} timesteps which does not match the configured sequence length "
1459
+ f"({sequence_length})."
1460
+ )
1461
+ sequences = arr.reshape(1, sequence_length, n_features)
1462
+ status, table, probs = predict_sequences(sequences)
1463
+ status = f"Single window prediction complete. {status}"
1464
+ return status, table, probs
1465
+
1466
+
1467
+ def predict_from_csv(
1468
+ file_obj, sequence_length: int, stride: int
1469
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
1470
+ df = load_measurement_csv(file_obj.name)
1471
+ sequences = dataframe_to_sequences(
1472
+ df,
1473
+ sequence_length=sequence_length,
1474
+ stride=stride,
1475
+ feature_columns=FEATURE_COLUMNS,
1476
+ )
1477
+ status, table, probs = predict_sequences(sequences)
1478
+ status = f"CSV processed successfully. Generated {len(sequences)} windows. {status}"
1479
+ return status, table, probs
1480
+
1481
+
1482
+ # --------------------------------------------------------------------------------------
1483
+ # Training helpers
1484
+ # --------------------------------------------------------------------------------------
1485
+
1486
+
1487
+ def classification_report_to_dataframe(report: Dict[str, Any]) -> pd.DataFrame:
1488
+ rows: List[Dict[str, Any]] = []
1489
+ for label, metrics in report.items():
1490
+ if isinstance(metrics, dict):
1491
+ row = {"label": label}
1492
+ for key, value in metrics.items():
1493
+ if key == "support":
1494
+ row[key] = int(value)
1495
+ else:
1496
+ row[key] = round(float(value), 4)
1497
+ rows.append(row)
1498
+ else:
1499
+ rows.append({"label": label, "accuracy": round(float(metrics), 4)})
1500
+ return pd.DataFrame(rows)
1501
+
1502
+
1503
+ def confusion_matrix_to_dataframe(
1504
+ confusion: Sequence[Sequence[float]], labels: Sequence[str]
1505
+ ) -> pd.DataFrame:
1506
+ if not confusion:
1507
+ return pd.DataFrame()
1508
+ df = pd.DataFrame(confusion, index=list(labels), columns=list(labels))
1509
+ df.index.name = "True Label"
1510
+ df.columns.name = "Predicted Label"
1511
+ return df
1512
+
1513
+
1514
+ # --------------------------------------------------------------------------------------
1515
+ # Gradio interface
1516
+ # --------------------------------------------------------------------------------------
1517
+
1518
+
1519
+ def build_interface() -> gr.Blocks:
1520
+ theme = gr.themes.Soft(
1521
+ primary_hue="sky", secondary_hue="blue", neutral_hue="gray"
1522
+ ).set(
1523
+ body_background_fill="#1f1f1f",
1524
+ body_text_color="#f5f5f5",
1525
+ block_background_fill="#262626",
1526
+ block_border_color="#333333",
1527
+ button_primary_background_fill="#5ac8fa",
1528
+ button_primary_background_fill_hover="#48b5eb",
1529
+ button_primary_border_color="#38bdf8",
1530
+ button_primary_text_color="#0f172a",
1531
+ button_secondary_background_fill="#3f3f46",
1532
+ button_secondary_text_color="#f5f5f5",
1533
+ )
1534
+
1535
+ def _normalise_directory_string(value: Optional[Union[str, Path]]) -> str:
1536
+ if value is None:
1537
+ return ""
1538
+ path = Path(value).expanduser()
1539
+ try:
1540
+ return str(path.resolve())
1541
+ except Exception:
1542
+ return str(path)
1543
+
1544
+ with gr.Blocks(
1545
+ title="Fault Classification - PMU Data", theme=theme, css=APP_CSS
1546
+ ) as demo:
1547
+ gr.Markdown("# Fault Classification for PMU & PV Data")
1548
+ gr.Markdown(
1549
+ "🖥️ TensorFlow is locked to CPU execution so the Space can run without CUDA drivers."
1550
+ )
1551
+ if MODEL is None or SCALER is None:
1552
+ gr.Markdown(
1553
+ "⚠️ **Artifacts Missing** — Upload `pmu_cnn_lstm_model.keras`, "
1554
+ "`pmu_feature_scaler.pkl`, and `pmu_metadata.json` to enable inference, "
1555
+ "or configure the Hugging Face Hub environment variables so they can be downloaded."
1556
+ )
1557
+ else:
1558
+ class_count = len(LABEL_CLASSES) if LABEL_CLASSES else "unknown"
1559
+ gr.Markdown(
1560
+ f"Loaded a **{MODEL_TYPE.upper()}** model ({MODEL_FORMAT.upper()}) with "
1561
+ f"{len(FEATURE_COLUMNS)} features, sequence length **{SEQUENCE_LENGTH}**, and "
1562
+ f"{class_count} target classes. Use the tabs below to run inference or fine-tune "
1563
+ "the model with your own CSV files."
1564
+ )
1565
+
1566
+ with gr.Accordion("Feature Reference", open=False):
1567
+ gr.Markdown(
1568
+ f"Each time window expects **{len(FEATURE_COLUMNS)} features** ordered as follows:\n"
1569
+ + "\n".join(f"- {name}" for name in FEATURE_COLUMNS)
1570
+ )
1571
+ gr.Markdown(
1572
+ f"Default training parameters: **sequence length = {SEQUENCE_LENGTH}**, "
1573
+ f"**stride = {DEFAULT_WINDOW_STRIDE}**. Adjust them in the tabs as needed."
1574
+ )
1575
+
1576
+ with gr.Tabs():
1577
+ with gr.Tab("Overview"):
1578
+ gr.Markdown(PROJECT_OVERVIEW_MD)
1579
+ with gr.Tab("Inference"):
1580
+ gr.Markdown("## Run Inference")
1581
+ with gr.Row():
1582
+ file_in = gr.File(label="Upload PMU CSV", file_types=[".csv"])
1583
+ text_in = gr.Textbox(
1584
+ lines=4,
1585
+ label="Or paste a single window (comma separated)",
1586
+ placeholder="49.97772,1.215825E-38,...",
1587
+ )
1588
+
1589
+ with gr.Row():
1590
+ sequence_length_input = gr.Slider(
1591
+ minimum=1,
1592
+ maximum=max(1, SEQUENCE_LENGTH * 2),
1593
+ step=1,
1594
+ value=SEQUENCE_LENGTH,
1595
+ label="Sequence length (timesteps)",
1596
+ )
1597
+ stride_input = gr.Slider(
1598
+ minimum=1,
1599
+ maximum=max(1, SEQUENCE_LENGTH),
1600
+ step=1,
1601
+ value=max(1, DEFAULT_WINDOW_STRIDE),
1602
+ label="CSV window stride",
1603
+ )
1604
+
1605
+ predict_btn = gr.Button("🚀 Run Inference", variant="primary")
1606
+ status_out = gr.Textbox(label="Status", interactive=False)
1607
+ table_out = gr.Dataframe(
1608
+ headers=["window", "predicted_label", "confidence", "top3"],
1609
+ label="Predictions",
1610
+ interactive=False,
1611
+ )
1612
+ probs_out = gr.JSON(label="Per-window probabilities")
1613
+
1614
+ def _run_prediction(file_obj, text, sequence_length, stride):
1615
+ sequence_length = int(sequence_length)
1616
+ stride = int(stride)
1617
+ try:
1618
+ if file_obj is not None:
1619
+ return predict_from_csv(file_obj, sequence_length, stride)
1620
+ if text and text.strip():
1621
+ return predict_from_text(text, sequence_length)
1622
+ return (
1623
+ "Please upload a CSV file or provide feature values.",
1624
+ pd.DataFrame(),
1625
+ [],
1626
+ )
1627
+ except Exception as exc:
1628
+ return f"Prediction failed: {exc}", pd.DataFrame(), []
1629
+
1630
+ predict_btn.click(
1631
+ _run_prediction,
1632
+ inputs=[file_in, text_in, sequence_length_input, stride_input],
1633
+ outputs=[status_out, table_out, probs_out],
1634
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1635
+ )
1636
+
1637
+ with gr.Tab("Training"):
1638
+ gr.Markdown("## Train or Fine-tune the Model")
1639
+ gr.Markdown(
1640
+ "Training data is automatically downloaded from the database. "
1641
+ "Refresh the cache if new files are added upstream."
1642
+ )
1643
+
1644
+ training_files_state = gr.State([])
1645
+ with gr.Row():
1646
+ with gr.Column(scale=3):
1647
+ training_files_summary = gr.Textbox(
1648
+ label="Database training CSVs",
1649
+ value="Training dataset not loaded yet.",
1650
+ lines=4,
1651
+ interactive=False,
1652
+ elem_id="training-files-summary",
1653
+ )
1654
+ with gr.Column(scale=2, min_width=240):
1655
+ dataset_info = gr.Markdown(
1656
+ "No local database CSVs downloaded yet.",
1657
+ )
1658
+ dataset_refresh = gr.Button(
1659
+ "🔄 Reload dataset from database",
1660
+ variant="secondary",
1661
+ )
1662
+ clear_cache_button = gr.Button(
1663
+ "🧹 Clear downloaded cache",
1664
+ variant="secondary",
1665
+ )
1666
+
1667
+ with gr.Accordion("📂 DataBaseBrowser", open=False):
1668
+ gr.Markdown(
1669
+ "Browse the upstream database by date and download only the CSVs you need."
1670
+ )
1671
+ with gr.Row(elem_id="date-browser-row"):
1672
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1673
+ year_selector = gr.Dropdown(label="Year", choices=[])
1674
+ year_download_button = gr.Button(
1675
+ "⬇️ Download year CSVs", variant="secondary"
1676
+ )
1677
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1678
+ month_selector = gr.Dropdown(label="Month", choices=[])
1679
+ month_download_button = gr.Button(
1680
+ "⬇️ Download month CSVs", variant="secondary"
1681
+ )
1682
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1683
+ day_selector = gr.Dropdown(label="Day", choices=[])
1684
+ day_download_button = gr.Button(
1685
+ "⬇️ Download day CSVs", variant="secondary"
1686
+ )
1687
+ with gr.Column(elem_id="available-files-section"):
1688
+ available_files = gr.CheckboxGroup(
1689
+ label="Available CSV files",
1690
+ choices=[],
1691
+ value=[],
1692
+ elem_id="available-files-grid",
1693
+ )
1694
+ download_button = gr.Button(
1695
+ "⬇️ Download selected CSVs",
1696
+ variant="secondary",
1697
+ elem_id="download-selected-button",
1698
+ )
1699
+ repo_status = gr.Markdown(
1700
+ "Click 'Reload dataset from database' to fetch the directory tree."
1701
+ )
1702
+
1703
+ with gr.Row():
1704
+ label_input = gr.Dropdown(
1705
+ value=LABEL_COLUMN,
1706
+ choices=[LABEL_COLUMN],
1707
+ allow_custom_value=True,
1708
+ label="Label column name",
1709
+ )
1710
+ model_selector = gr.Radio(
1711
+ choices=["CNN-LSTM", "TCN", "SVM"],
1712
+ value=(
1713
+ "TCN"
1714
+ if MODEL_TYPE == "tcn"
1715
+ else ("SVM" if MODEL_TYPE == "svm" else "CNN-LSTM")
1716
+ ),
1717
+ label="Model architecture",
1718
+ )
1719
+ sequence_length_train = gr.Slider(
1720
+ minimum=4,
1721
+ maximum=max(32, SEQUENCE_LENGTH * 2),
1722
+ step=1,
1723
+ value=SEQUENCE_LENGTH,
1724
+ label="Sequence length",
1725
+ )
1726
+ stride_train = gr.Slider(
1727
+ minimum=1,
1728
+ maximum=max(32, SEQUENCE_LENGTH * 2),
1729
+ step=1,
1730
+ value=max(1, DEFAULT_WINDOW_STRIDE),
1731
+ label="Stride",
1732
+ )
1733
+
1734
+ model_default = MODEL_FILENAME_BY_TYPE.get(
1735
+ MODEL_TYPE, Path(LOCAL_MODEL_FILE).name
1736
+ )
1737
+
1738
+ with gr.Row():
1739
+ validation_train = gr.Slider(
1740
+ minimum=0.05,
1741
+ maximum=0.4,
1742
+ step=0.05,
1743
+ value=0.2,
1744
+ label="Validation split",
1745
+ )
1746
+ batch_train = gr.Slider(
1747
+ minimum=32,
1748
+ maximum=512,
1749
+ step=32,
1750
+ value=128,
1751
+ label="Batch size",
1752
+ )
1753
+ epochs_train = gr.Slider(
1754
+ minimum=5,
1755
+ maximum=100,
1756
+ step=5,
1757
+ value=50,
1758
+ label="Epochs",
1759
+ )
1760
+
1761
+ directory_choices, directory_default = gather_directory_choices(
1762
+ str(MODEL_OUTPUT_DIR)
1763
+ )
1764
+ artifact_choices, default_artifact = gather_artifact_choices(
1765
+ directory_default
1766
+ )
1767
+
1768
+ with gr.Row():
1769
+ output_directory = gr.Dropdown(
1770
+ value=directory_default,
1771
+ label="Output directory",
1772
+ choices=directory_choices,
1773
+ allow_custom_value=True,
1774
+ )
1775
+ model_name = gr.Textbox(
1776
+ value=model_default,
1777
+ label="Model output filename",
1778
+ )
1779
+ scaler_name = gr.Textbox(
1780
+ value=Path(LOCAL_SCALER_FILE).name,
1781
+ label="Scaler output filename",
1782
+ )
1783
+ metadata_name = gr.Textbox(
1784
+ value=Path(LOCAL_METADATA_FILE).name,
1785
+ label="Metadata output filename",
1786
+ )
1787
+
1788
+ with gr.Row():
1789
+ artifact_browser = gr.Dropdown(
1790
+ label="Saved artifacts in directory",
1791
+ choices=artifact_choices,
1792
+ value=default_artifact,
1793
+ )
1794
+ artifact_download_button = gr.DownloadButton(
1795
+ "⬇️ Download selected artifact",
1796
+ value=default_artifact,
1797
+ visible=bool(default_artifact),
1798
+ variant="secondary",
1799
+ )
1800
+
1801
+ def on_output_directory_change(selected_dir, current_selection):
1802
+ choices, normalised = gather_directory_choices(selected_dir)
1803
+ artifact_options, selected = gather_artifact_choices(
1804
+ normalised, current_selection
1805
+ )
1806
+ return (
1807
+ gr.update(choices=choices, value=normalised),
1808
+ gr.update(choices=artifact_options, value=selected),
1809
+ download_button_state(selected),
1810
+ )
1811
+
1812
+ def on_artifact_change(selected_path):
1813
+ return download_button_state(selected_path)
1814
+
1815
+ output_directory.change(
1816
+ on_output_directory_change,
1817
+ inputs=[output_directory, artifact_browser],
1818
+ outputs=[
1819
+ output_directory,
1820
+ artifact_browser,
1821
+ artifact_download_button,
1822
+ ],
1823
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1824
+ )
1825
+
1826
+ artifact_browser.change(
1827
+ on_artifact_change,
1828
+ inputs=[artifact_browser],
1829
+ outputs=[artifact_download_button],
1830
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1831
+ )
1832
+
1833
+ with gr.Row(elem_id="artifact-download-row"):
1834
+ model_download_button = gr.DownloadButton(
1835
+ "⬇️ Download model file",
1836
+ value=None,
1837
+ visible=False,
1838
+ elem_classes=["artifact-download-button"],
1839
+ )
1840
+ scaler_download_button = gr.DownloadButton(
1841
+ "⬇️ Download scaler file",
1842
+ value=None,
1843
+ visible=False,
1844
+ elem_classes=["artifact-download-button"],
1845
+ )
1846
+ metadata_download_button = gr.DownloadButton(
1847
+ "⬇️ Download metadata file",
1848
+ value=None,
1849
+ visible=False,
1850
+ elem_classes=["artifact-download-button"],
1851
+ )
1852
+ tensorboard_download_button = gr.DownloadButton(
1853
+ "⬇️ Download TensorBoard logs",
1854
+ value=None,
1855
+ visible=False,
1856
+ elem_classes=["artifact-download-button"],
1857
+ )
1858
+
1859
+ model_download_button.file_name = Path(LOCAL_MODEL_FILE).name
1860
+ scaler_download_button.file_name = Path(LOCAL_SCALER_FILE).name
1861
+ metadata_download_button.file_name = Path(LOCAL_METADATA_FILE).name
1862
+ tensorboard_download_button.file_name = "tensorboard_logs.zip"
1863
+
1864
+ tensorboard_toggle = gr.Checkbox(
1865
+ value=True,
1866
+ label="Enable TensorBoard logging (creates downloadable archive)",
1867
+ )
1868
+
1869
+ def _suggest_model_filename(choice: str, current_value: str):
1870
+ choice_key = (choice or "cnn_lstm").lower().replace("-", "_")
1871
+ suggested = MODEL_FILENAME_BY_TYPE.get(
1872
+ choice_key, Path(LOCAL_MODEL_FILE).name
1873
+ )
1874
+ known_defaults = set(MODEL_FILENAME_BY_TYPE.values())
1875
+ current_name = Path(current_value).name if current_value else ""
1876
+ if current_name and current_name not in known_defaults:
1877
+ return gr.update()
1878
+ return gr.update(value=suggested)
1879
+
1880
+ model_selector.change(
1881
+ _suggest_model_filename,
1882
+ inputs=[model_selector, model_name],
1883
+ outputs=model_name,
1884
+ )
1885
+
1886
+ with gr.Row():
1887
+ train_button = gr.Button("🛠️ Start Training", variant="primary")
1888
+ progress_button = gr.Button(
1889
+ "📊 Check Progress", variant="secondary"
1890
+ )
1891
+
1892
+ # Training status display
1893
+ training_status = gr.Textbox(label="Training Status", interactive=False)
1894
+ report_output = gr.Dataframe(
1895
+ label="Classification report", interactive=False
1896
+ )
1897
+ history_output = gr.JSON(label="Training history")
1898
+ confusion_output = gr.Dataframe(
1899
+ label="Confusion matrix", interactive=False
1900
+ )
1901
+
1902
+ # Message area at the bottom for progress updates
1903
+ with gr.Accordion("📋 Progress Messages", open=True):
1904
+ progress_messages = gr.Textbox(
1905
+ label="Training Messages",
1906
+ lines=8,
1907
+ max_lines=20,
1908
+ interactive=False,
1909
+ autoscroll=True,
1910
+ placeholder="Click 'Check Progress' to see training updates...",
1911
+ )
1912
+ with gr.Row():
1913
+ gr.Button("🗑️ Clear Messages", variant="secondary").click(
1914
+ lambda: "", outputs=[progress_messages]
1915
+ )
1916
+
1917
+ def _run_training(
1918
+ file_paths,
1919
+ label_column,
1920
+ model_choice,
1921
+ sequence_length,
1922
+ stride,
1923
+ validation_split,
1924
+ batch_size,
1925
+ epochs,
1926
+ output_dir,
1927
+ model_filename,
1928
+ scaler_filename,
1929
+ metadata_filename,
1930
+ enable_tensorboard,
1931
+ ):
1932
+ base_dir = normalise_output_directory(output_dir)
1933
+ try:
1934
+ base_dir.mkdir(parents=True, exist_ok=True)
1935
+
1936
+ model_path = resolve_output_path(
1937
+ base_dir,
1938
+ model_filename,
1939
+ Path(LOCAL_MODEL_FILE).name,
1940
+ )
1941
+ scaler_path = resolve_output_path(
1942
+ base_dir,
1943
+ scaler_filename,
1944
+ Path(LOCAL_SCALER_FILE).name,
1945
+ )
1946
+ metadata_path = resolve_output_path(
1947
+ base_dir,
1948
+ metadata_filename,
1949
+ Path(LOCAL_METADATA_FILE).name,
1950
+ )
1951
+
1952
+ model_path.parent.mkdir(parents=True, exist_ok=True)
1953
+ scaler_path.parent.mkdir(parents=True, exist_ok=True)
1954
+ metadata_path.parent.mkdir(parents=True, exist_ok=True)
1955
+
1956
+ # Create status file path for progress tracking
1957
+ status_file = model_path.parent / "training_status.txt"
1958
+
1959
+ # Initialize status
1960
+ with open(status_file, "w") as f:
1961
+ f.write("Starting training setup...")
1962
+
1963
+ if not file_paths:
1964
+ raise ValueError(
1965
+ "No training CSVs were found in the database cache. "
1966
+ "Use 'Reload dataset from database' and try again."
1967
+ )
1968
+
1969
+ with open(status_file, "w") as f:
1970
+ f.write("Loading and validating CSV files...")
1971
+
1972
+ available_paths = [
1973
+ path for path in file_paths if Path(path).exists()
1974
+ ]
1975
+ missing_paths = [
1976
+ Path(path).name
1977
+ for path in file_paths
1978
+ if not Path(path).exists()
1979
+ ]
1980
+ if not available_paths:
1981
+ raise ValueError(
1982
+ "Database training dataset is unavailable. Reload the dataset and retry."
1983
+ )
1984
+
1985
+ dfs = [load_measurement_csv(path) for path in available_paths]
1986
+ combined = pd.concat(dfs, ignore_index=True)
1987
+
1988
+ # Validate data size and provide recommendations
1989
+ total_samples = len(combined)
1990
+ if total_samples < 100:
1991
+ print(
1992
+ f"Warning: Only {total_samples} samples. Recommend at least 1000 for good results."
1993
+ )
1994
+ # The radio control supplies "CNN-LSTM"/"TCN"; normalise before comparing,
+ # otherwise the automatic switch to SVM never fires.
+ if (model_choice or "").lower().replace("-", "_") in {"cnn_lstm", "tcn"}:
+ model_choice = "svm"
+ print(
+ "Automatically switched to SVM for better small-dataset performance."
+ )
2002
+ if total_samples < 10:
2003
+ raise ValueError(
2004
+ f"Insufficient data: {total_samples} samples. Need at least 10 samples for training."
2005
+ )
2006
+
2007
+ label_column = (label_column or LABEL_COLUMN).strip()
2008
+ if not label_column:
2009
+ raise ValueError("Label column name cannot be empty.")
2010
+
2011
+ model_choice = (
2012
+ (model_choice or "CNN-LSTM").lower().replace("-", "_")
2013
+ )
2014
+ if model_choice not in {"cnn_lstm", "tcn", "svm"}:
2015
+ raise ValueError(
2016
+ "Select CNN-LSTM, TCN, or SVM for the model architecture."
2017
+ )
2018
+
2019
+ with open(status_file, "w") as f:
2020
+ f.write(
2021
+ f"Starting {model_choice.upper()} training with {len(combined)} samples..."
2022
+ )
2023
+
2024
+ # Start training
2025
+ result = train_from_dataframe(
2026
+ combined,
2027
+ label_column=label_column,
2028
+ feature_columns=None,
2029
+ sequence_length=int(sequence_length),
2030
+ stride=int(stride),
2031
+ validation_split=float(validation_split),
2032
+ batch_size=int(batch_size),
2033
+ epochs=int(epochs),
2034
+ model_type=model_choice,
2035
+ model_path=model_path,
2036
+ scaler_path=scaler_path,
2037
+ metadata_path=metadata_path,
2038
+ enable_tensorboard=bool(enable_tensorboard),
2039
+ )
2040
+
2041
+ refresh_artifacts(
2042
+ Path(result["model_path"]),
2043
+ Path(result["scaler_path"]),
2044
+ Path(result["metadata_path"]),
2045
+ )
2046
+
2047
+ report_df = classification_report_to_dataframe(
2048
+ result["classification_report"]
2049
+ )
2050
+ confusion_df = confusion_matrix_to_dataframe(
2051
+ result["confusion_matrix"], result["class_names"]
2052
+ )
2053
+ tensorboard_dir = result.get("tensorboard_log_dir")
2054
+ tensorboard_zip = result.get("tensorboard_zip_path")
2055
+
2056
+ architecture = result["model_type"].replace("_", "-").upper()
2057
+ status = (
2058
+ f"Training complete using a {architecture} architecture. "
2059
+ f"{result['num_sequences']} windows derived from "
2060
+ f"{result['num_samples']} rows across {len(available_paths)} file(s)."
2061
+ f" Artifacts saved to:"
2062
+ f"\n• Model: {result['model_path']}\n"
2063
+ f"• Scaler: {result['scaler_path']}\n"
2064
+ f"• Metadata: {result['metadata_path']}"
2065
+ )
2066
+
2067
+ status += f"\nLabel column used: {result.get('label_column', label_column)}"
2068
+
2069
+ if tensorboard_dir:
2070
+ status += (
2071
+ f"\nTensorBoard logs directory: {tensorboard_dir}"
2072
+ f'\nRun `tensorboard --logdir "{tensorboard_dir}"` to inspect the training curves.'
2073
+ "\nDownload the archive below to explore the run offline."
2074
+ )
2075
+
2076
+ if missing_paths:
2077
+ skipped = ", ".join(missing_paths)
2078
+ status = f"⚠️ Skipped missing files: {skipped}\n" + status
2079
+
2080
+ artifact_choices, selected_artifact = gather_artifact_choices(
2081
+ str(base_dir), result["model_path"]
2082
+ )
2083
+
2084
+ return (
2085
+ status,
2086
+ report_df,
2087
+ result["history"],
2088
+ confusion_df,
2089
+ download_button_state(result["model_path"]),
2090
+ download_button_state(result["scaler_path"]),
2091
+ download_button_state(result["metadata_path"]),
2092
+ download_button_state(tensorboard_zip),
2093
+ gr.update(value=result.get("label_column", label_column)),
2094
+ gr.update(
2095
+ choices=artifact_choices, value=selected_artifact
2096
+ ),
2097
+ download_button_state(selected_artifact),
2098
+ )
2099
+ except Exception as exc:
2100
+ artifact_choices, selected_artifact = gather_artifact_choices(
2101
+ str(base_dir)
2102
+ )
2103
+ return (
2104
+ f"Training failed: {exc}",
2105
+ pd.DataFrame(),
2106
+ {},
2107
+ pd.DataFrame(),
2108
+ download_button_state(None),
2109
+ download_button_state(None),
2110
+ download_button_state(None),
2111
+ download_button_state(None),
2112
+ gr.update(),
2113
+ gr.update(
2114
+ choices=artifact_choices, value=selected_artifact
2115
+ ),
2116
+ download_button_state(selected_artifact),
2117
+ )
2118
+
2119
+ def _check_progress(output_dir, model_filename, current_messages):
2120
+ """Check training progress by reading status file and accumulate messages."""
2121
+ model_path = resolve_output_path(
2122
+ output_dir, model_filename, Path(LOCAL_MODEL_FILE).name
2123
+ )
2124
+ status_file = model_path.parent / "training_status.txt"
2125
+ status_message = read_training_status(str(status_file))
2126
+
2127
+ # Add timestamp to the message
2128
+ from datetime import datetime
2129
+
2130
+ timestamp = datetime.now().strftime("%H:%M:%S")
2131
+ new_message = f"[{timestamp}] {status_message}"
2132
+
2133
+ # Accumulate messages, keeping last 50 lines to prevent overflow
2134
+ if current_messages:
2135
+ lines = current_messages.split("\n")
2136
+ lines.append(new_message)
2137
+ # Keep only last 50 lines
2138
+ if len(lines) > 50:
2139
+ lines = lines[-50:]
2140
+ accumulated_messages = "\n".join(lines)
2141
+ else:
2142
+ accumulated_messages = new_message
2143
+
2144
+ return accumulated_messages
2145
+
2146
+ train_button.click(
2147
+ _run_training,
2148
+ inputs=[
2149
+ training_files_state,
2150
+ label_input,
2151
+ model_selector,
2152
+ sequence_length_train,
2153
+ stride_train,
2154
+ validation_train,
2155
+ batch_train,
2156
+ epochs_train,
2157
+ output_directory,
2158
+ model_name,
2159
+ scaler_name,
2160
+ metadata_name,
2161
+ tensorboard_toggle,
2162
+ ],
2163
+ outputs=[
2164
+ training_status,
2165
+ report_output,
2166
+ history_output,
2167
+ confusion_output,
2168
+ model_download_button,
2169
+ scaler_download_button,
2170
+ metadata_download_button,
2171
+ tensorboard_download_button,
2172
+ label_input,
2173
+ artifact_browser,
2174
+ artifact_download_button,
2175
+ ],
2176
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2177
+ )
2178
+
2179
+ progress_button.click(
2180
+ _check_progress,
2181
+ inputs=[output_directory, model_name, progress_messages],
2182
+ outputs=[progress_messages],
2183
+ )
2184
+
2185
+ year_selector.change(
2186
+ on_year_change,
2187
+ inputs=[year_selector],
2188
+ outputs=[
2189
+ month_selector,
2190
+ day_selector,
2191
+ available_files,
2192
+ repo_status,
2193
+ ],
2194
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2195
+ )
2196
+
2197
+ month_selector.change(
2198
+ on_month_change,
2199
+ inputs=[year_selector, month_selector],
2200
+ outputs=[day_selector, available_files, repo_status],
2201
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2202
+ )
2203
+
2204
+ day_selector.change(
2205
+ on_day_change,
2206
+ inputs=[year_selector, month_selector, day_selector],
2207
+ outputs=[available_files, repo_status],
2208
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2209
+ )
2210
+
2211
+ download_button.click(
2212
+ download_selected_files,
2213
+ inputs=[
2214
+ year_selector,
2215
+ month_selector,
2216
+ day_selector,
2217
+ available_files,
2218
+ label_input,
2219
+ ],
2220
+ outputs=[
2221
+ training_files_state,
2222
+ training_files_summary,
2223
+ label_input,
2224
+ dataset_info,
2225
+ available_files,
2226
+ repo_status,
2227
+ ],
2228
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2229
+ )
2230
+
2231
+ year_download_button.click(
2232
+ download_year_bundle,
2233
+ inputs=[year_selector, label_input],
2234
+ outputs=[
2235
+ training_files_state,
2236
+ training_files_summary,
2237
+ label_input,
2238
+ dataset_info,
2239
+ available_files,
2240
+ repo_status,
2241
+ ],
2242
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2243
+ )
2244
+
2245
+ month_download_button.click(
2246
+ download_month_bundle,
2247
+ inputs=[year_selector, month_selector, label_input],
2248
+ outputs=[
2249
+ training_files_state,
2250
+ training_files_summary,
2251
+ label_input,
2252
+ dataset_info,
2253
+ available_files,
2254
+ repo_status,
2255
+ ],
2256
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2257
+ )
2258
+
2259
+ day_download_button.click(
2260
+ download_day_bundle,
2261
+ inputs=[year_selector, month_selector, day_selector, label_input],
2262
+ outputs=[
2263
+ training_files_state,
2264
+ training_files_summary,
2265
+ label_input,
2266
+ dataset_info,
2267
+ available_files,
2268
+ repo_status,
2269
+ ],
2270
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2271
+ )
2272
+
2273
+ def _reload_dataset(current_label):
2274
+ local = load_repository_training_files(
2275
+ current_label, force_refresh=True
2276
+ )
2277
+ remote = refresh_remote_browser(force_refresh=True)
2278
+ return (*local, *remote)
2279
+
2280
+ dataset_refresh.click(
2281
+ _reload_dataset,
2282
+ inputs=[label_input],
2283
+ outputs=[
2284
+ training_files_state,
2285
+ training_files_summary,
2286
+ label_input,
2287
+ dataset_info,
2288
+ year_selector,
2289
+ month_selector,
2290
+ day_selector,
2291
+ available_files,
2292
+ repo_status,
2293
+ ],
2294
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2295
+ )
2296
+
2297
+ clear_cache_button.click(
2298
+ clear_downloaded_cache,
2299
+ inputs=[label_input],
2300
+ outputs=[
2301
+ training_files_state,
2302
+ training_files_summary,
2303
+ label_input,
2304
+ dataset_info,
2305
+ year_selector,
2306
+ month_selector,
2307
+ day_selector,
2308
+ available_files,
2309
+ repo_status,
2310
+ ],
2311
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2312
+ )
2313
+
2314
+ def _initialise_dataset():
2315
+ local = load_repository_training_files(
2316
+ LABEL_COLUMN, force_refresh=False
2317
+ )
2318
+ remote = refresh_remote_browser(force_refresh=False)
2319
+ return (*local, *remote)
2320
+
2321
+ demo.load(
2322
+ _initialise_dataset,
2323
+ inputs=None,
2324
+ outputs=[
2325
+ training_files_state,
2326
+ training_files_summary,
2327
+ label_input,
2328
+ dataset_info,
2329
+ year_selector,
2330
+ month_selector,
2331
+ day_selector,
2332
+ available_files,
2333
+ repo_status,
2334
+ ],
2335
+ queue=False,
2336
+ )
2337
+
2338
+ return demo
2339
+
2340
+
2341
+ # --------------------------------------------------------------------------------------
2342
+ # Launch helpers
2343
+ # --------------------------------------------------------------------------------------
2344
+
2345
+
2346
+ def resolve_server_port() -> int:
2347
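+ """Read the server port from PORT or GRADIO_SERVER_PORT, defaulting to 7860."""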
+ for env_var in ("PORT", "GRADIO_SERVER_PORT"):
2348
+ value = os.environ.get(env_var)
2349
+ if value:
2350
+ try:
2351
+ return int(value)
2352
+ except ValueError:
2353
+ print(f"Ignoring invalid port value from {env_var}: {value}")
2354
+ return 7860
2355
+
2356
+
2357
+ def main():
2358
+ print("Building Gradio interface...")
2359
+ try:
2360
+ demo = build_interface()
2361
+ print("Interface built successfully")
2362
+ except Exception as e:
2363
+ print(f"Failed to build interface: {e}")
2364
+ import traceback
2365
+
2366
+ traceback.print_exc()
2367
+ return
2368
+
2369
+ print("Setting up queue...")
2370
+ try:
2371
+ demo.queue(max_size=QUEUE_MAX_SIZE)
2372
+ print("Queue configured")
2373
+ except Exception as e:
2374
+ print(f"Failed to configure queue: {e}")
2375
+
2376
+ try:
2377
+ port = resolve_server_port()
2378
+ print(f"Launching Gradio app on port {port}")
2379
+ demo.launch(server_name="0.0.0.0", server_port=port, show_error=True)
2380
+ except OSError as exc:
2381
+ print("Failed to launch on requested port:", exc)
2382
+ try:
2383
+ demo.launch(server_name="0.0.0.0", show_error=True)
2384
+ except Exception as e:
2385
+ print(f"Failed to launch completely: {e}")
2386
+ except Exception as e:
2387
+ print(f"Unexpected launch error: {e}")
2388
+ import traceback
2389
+
2390
+ traceback.print_exc()
2391
+
2392
+
2393
+ if __name__ == "__main__":
2394
+ print("=" * 50)
2395
+ print("PMU Fault Classification App Starting")
2396
+ print(f"Python version: {os.sys.version}")
2397
+ print(f"Working directory: {os.getcwd()}")
2398
+ print(f"HUB_REPO: {HUB_REPO}")
2399
+ print(f"Model available: {MODEL is not None}")
2400
+ print(f"Scaler available: {SCALER is not None}")
2401
+ print("=" * 50)
2402
+ main()
.history/app_20251009232235.py ADDED
@@ -0,0 +1,2402 @@
1
+ """Gradio front-end for Fault_Classification_PMU_Data models.
2
+
3
+ The application loads a trained model (a CNN-LSTM by default; TCN and SVM
+ variants are also recognised via the metadata) together with the accompanying
+ scaler and metadata produced by ``fault_classification_pmu.py``, and exposes a
+ streamlined prediction interface optimised for Hugging Face Spaces
+ deployment. It supports raw PMU time-series CSV uploads as well as manual
+ comma-separated feature vectors.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import os
14
+ import shutil
15
+
16
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1")
17
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
18
+ os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
19
+
20
+ import re
21
+ from pathlib import Path
22
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
23
+
24
+ import gradio as gr
25
+ import joblib
26
+ import numpy as np
27
+ import pandas as pd
28
+ import requests
29
+ from huggingface_hub import hf_hub_download
30
+ from tensorflow.keras.models import load_model
31
+
32
+ from fault_classification_pmu import (
33
+ DEFAULT_FEATURE_COLUMNS as TRAINING_DEFAULT_FEATURE_COLUMNS,
34
+ LABEL_GUESS_CANDIDATES as TRAINING_LABEL_GUESSES,
35
+ train_from_dataframe,
36
+ )
37
+
38
+ # --------------------------------------------------------------------------------------
39
+ # Configuration
40
+ # --------------------------------------------------------------------------------------
41
+ DEFAULT_FEATURE_COLUMNS: List[str] = list(TRAINING_DEFAULT_FEATURE_COLUMNS)
42
+ DEFAULT_SEQUENCE_LENGTH = 32
43
+ DEFAULT_STRIDE = 4
44
+
45
+ LOCAL_MODEL_FILE = os.environ.get("PMU_MODEL_FILE", "pmu_cnn_lstm_model.keras")
46
+ LOCAL_SCALER_FILE = os.environ.get("PMU_SCALER_FILE", "pmu_feature_scaler.pkl")
47
+ LOCAL_METADATA_FILE = os.environ.get("PMU_METADATA_FILE", "pmu_metadata.json")
48
+
49
+ MODEL_OUTPUT_DIR = Path(os.environ.get("PMU_MODEL_DIR", "model")).resolve()
50
+ MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
51
+
52
+ HUB_REPO = os.environ.get("PMU_HUB_REPO", "")
53
+ HUB_MODEL_FILENAME = os.environ.get("PMU_HUB_MODEL_FILENAME", LOCAL_MODEL_FILE)
54
+ HUB_SCALER_FILENAME = os.environ.get("PMU_HUB_SCALER_FILENAME", LOCAL_SCALER_FILE)
55
+ HUB_METADATA_FILENAME = os.environ.get("PMU_HUB_METADATA_FILENAME", LOCAL_METADATA_FILE)
56
+
57
+ ENV_MODEL_PATH = "PMU_MODEL_PATH"
58
+ ENV_SCALER_PATH = "PMU_SCALER_PATH"
59
+ ENV_METADATA_PATH = "PMU_METADATA_PATH"
60
+
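+ # Configuration sketch (illustrative values, not shipped defaults): every
+ # artifact location above can be overridden through environment variables
+ # before the Space starts, e.g.:
+ #
+ #   export PMU_HUB_REPO="your-username/pmu-fault-models"
+ #   export PMU_MODEL_FILE="pmu_tcn_model.keras"
+ #   export PMU_MODEL_DIR="/data/model"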
61
+ # --------------------------------------------------------------------------------------
62
+ # Utility functions for loading artifacts
63
+ # --------------------------------------------------------------------------------------
64
+
65
+
66
+ def download_from_hub(filename: str) -> Optional[Path]:
67
+ if not HUB_REPO or not filename:
68
+ return None
69
+ try:
70
+ print(f"Downloading {filename} from {HUB_REPO} ...")
71
+ # No explicit timeout is passed here; hf_hub_download uses its library defaults
72
+ path = hf_hub_download(repo_id=HUB_REPO, filename=filename)
73
+ print("Downloaded", path)
74
+ return Path(path)
75
+ except Exception as exc: # pragma: no cover - logging convenience
76
+ print("Failed to download", filename, "from", HUB_REPO, ":", exc)
77
+ print("Continuing without pre-trained model...")
78
+ return None
79
+
80
+
81
+ def resolve_artifact(
82
+ local_name: str, env_var: str, hub_filename: str
83
+ ) -> Optional[Path]:
84
+ print(f"Resolving artifact: {local_name}, env: {env_var}, hub: {hub_filename}")
85
+ candidates = [Path(local_name)] if local_name else []
86
+ if local_name:
87
+ candidates.append(MODEL_OUTPUT_DIR / Path(local_name).name)
88
+ env_value = os.environ.get(env_var)
89
+ if env_value:
90
+ candidates.append(Path(env_value))
91
+
92
+ for candidate in candidates:
93
+ if candidate and candidate.exists():
94
+ print(f"Found local artifact: {candidate}")
95
+ return candidate
96
+
97
+ print("No local artifacts found, checking hub...")
98
+ # Only try to download if we have a hub repo configured
99
+ if HUB_REPO:
100
+ return download_from_hub(hub_filename)
101
+ else:
102
+ print("No HUB_REPO configured, skipping download")
103
+ return None
104
+
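+ # Resolution order sketch for, e.g., the metadata artifact:
+ #   resolve_artifact(LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME)
+ # checks ./pmu_metadata.json first, then model/pmu_metadata.json, then the
+ # path named by $PMU_METADATA_PATH, and only falls back to hf_hub_download
+ # when PMU_HUB_REPO is configured; it returns None when nothing is found.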
105
+
106
+ def load_metadata(path: Optional[Path]) -> Dict:
107
+ if path and path.exists():
108
+ try:
109
+ return json.loads(path.read_text())
110
+ except Exception as exc: # pragma: no cover - metadata parsing errors
111
+ print("Failed to read metadata", path, exc)
112
+ return {}
113
+
114
+
115
+ def try_load_scaler(path: Optional[Path]):
116
+ if not path:
117
+ return None
118
+ try:
119
+ scaler = joblib.load(path)
120
+ print("Loaded scaler from", path)
121
+ return scaler
122
+ except Exception as exc:
123
+ print("Failed to load scaler", path, exc)
124
+ return None
125
+
126
+
127
+ # Initialize paths with error handling
128
+ print("Starting application initialization...")
129
+ try:
130
+ MODEL_PATH = resolve_artifact(LOCAL_MODEL_FILE, ENV_MODEL_PATH, HUB_MODEL_FILENAME)
131
+ print(f"Model path resolved: {MODEL_PATH}")
132
+ except Exception as e:
133
+ print(f"Model path resolution failed: {e}")
134
+ MODEL_PATH = None
135
+
136
+ try:
137
+ SCALER_PATH = resolve_artifact(
138
+ LOCAL_SCALER_FILE, ENV_SCALER_PATH, HUB_SCALER_FILENAME
139
+ )
140
+ print(f"Scaler path resolved: {SCALER_PATH}")
141
+ except Exception as e:
142
+ print(f"Scaler path resolution failed: {e}")
143
+ SCALER_PATH = None
144
+
145
+ try:
146
+ METADATA_PATH = resolve_artifact(
147
+ LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME
148
+ )
149
+ print(f"Metadata path resolved: {METADATA_PATH}")
150
+ except Exception as e:
151
+ print(f"Metadata path resolution failed: {e}")
152
+ METADATA_PATH = None
153
+
154
+ try:
155
+ METADATA = load_metadata(METADATA_PATH)
156
+ print(f"Metadata loaded: {len(METADATA)} entries")
157
+ except Exception as e:
158
+ print(f"Metadata loading failed: {e}")
159
+ METADATA = {}
160
+
161
+ # Queuing configuration
162
+ QUEUE_MAX_SIZE = 32
163
+ # Apply a small per-event concurrency limit to avoid relying on the deprecated
164
+ # ``concurrency_count`` parameter when enabling Gradio's request queue.
165
+ EVENT_CONCURRENCY_LIMIT = 2
166
+
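+ # A minimal sketch of how these constants are expected to be applied at
+ # launch time (Gradio 4.x API; this mirrors the actual ``.click()`` wiring
+ # used in the Inference tab below):
+ #
+ #   demo.queue(max_size=QUEUE_MAX_SIZE)
+ #   predict_btn.click(fn, inputs=..., outputs=...,
+ #                     concurrency_limit=EVENT_CONCURRENCY_LIMIT)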
167
+
168
+ def try_load_model(path: Optional[Path], model_type: str, model_format: str):
169
+ if not path:
170
+ return None
171
+ try:
172
+ if model_type == "svm" or model_format == "joblib":
173
+ model = joblib.load(path)
174
+ else:
175
+ model = load_model(path)
176
+ print("Loaded model from", path)
177
+ return model
178
+ except Exception as exc: # pragma: no cover - runtime diagnostics
179
+ print("Failed to load model", path, exc)
180
+ return None
181
+
182
+
183
+ FEATURE_COLUMNS: List[str] = list(DEFAULT_FEATURE_COLUMNS)
184
+ LABEL_CLASSES: List[str] = []
185
+ LABEL_COLUMN: str = "Fault"
186
+ SEQUENCE_LENGTH: int = DEFAULT_SEQUENCE_LENGTH
187
+ DEFAULT_WINDOW_STRIDE: int = DEFAULT_STRIDE
188
+ MODEL_TYPE: str = "cnn_lstm"
189
+ MODEL_FORMAT: str = "keras"
190
+
191
+
192
+ def _model_output_path(filename: str) -> str:
193
+ return str(MODEL_OUTPUT_DIR / Path(filename).name)
194
+
195
+
196
+ MODEL_FILENAME_BY_TYPE: Dict[str, str] = {
197
+ "cnn_lstm": Path(LOCAL_MODEL_FILE).name,
198
+ "tcn": "pmu_tcn_model.keras",
199
+ "svm": "pmu_svm_model.joblib",
200
+ }
201
+
202
+ REQUIRED_PMU_COLUMNS: Tuple[str, ...] = tuple(DEFAULT_FEATURE_COLUMNS)
203
+ TRAINING_UPLOAD_DIR = Path(
204
+ os.environ.get("PMU_TRAINING_UPLOAD_DIR", "training_uploads")
205
+ )
206
+ TRAINING_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
207
+
208
+ TRAINING_DATA_REPO = os.environ.get(
209
+ "PMU_TRAINING_DATA_REPO", "VincentCroft/ThesisModelData"
210
+ )
211
+ TRAINING_DATA_BRANCH = os.environ.get("PMU_TRAINING_DATA_BRANCH", "main")
212
+ TRAINING_DATA_DIR = Path(os.environ.get("PMU_TRAINING_DATA_DIR", "training_dataset"))
213
+ TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
214
+
215
+ GITHUB_CONTENT_CACHE: Dict[str, List[Dict[str, Any]]] = {}
216
+
217
+
218
+ APP_CSS = """
219
+ #available-files-section {
220
+ position: relative;
221
+ display: flex;
222
+ flex-direction: column;
223
+ gap: 0.75rem;
224
+ border-radius: 0.75rem;
225
+ isolation: isolate;
226
+ }
227
+
228
+ #available-files-grid {
229
+ position: static;
230
+ overflow: visible;
231
+ }
232
+
233
+ #available-files-grid .form {
234
+ position: static;
235
+ min-height: 16rem;
236
+ }
237
+
238
+ #available-files-grid .wrap {
239
+ display: grid;
240
+ grid-template-columns: repeat(4, minmax(0, 1fr));
241
+ gap: 0.5rem;
242
+ max-height: 24rem;
243
+ min-height: 16rem;
244
+ overflow-y: auto;
245
+ padding-right: 0.25rem;
246
+ }
247
+
248
+ #available-files-grid .wrap > div {
249
+ min-width: 0;
250
+ }
251
+
252
+ #available-files-grid .wrap label {
253
+ margin: 0;
254
+ display: flex;
255
+ align-items: center;
256
+ padding: 0.45rem 0.65rem;
257
+ border-radius: 0.65rem;
258
+ background-color: rgba(255, 255, 255, 0.05);
259
+ border: 1px solid rgba(255, 255, 255, 0.08);
260
+ transition: background-color 0.2s ease, border-color 0.2s ease;
261
+ min-height: 2.5rem;
262
+ }
263
+
264
+ #available-files-grid .wrap label:hover {
265
+ background-color: rgba(90, 200, 250, 0.16);
266
+ border-color: rgba(90, 200, 250, 0.4);
267
+ }
268
+
269
+ #available-files-grid .wrap label span {
270
+ overflow: hidden;
271
+ text-overflow: ellipsis;
272
+ white-space: nowrap;
273
+ }
274
+
275
+ #available-files-section .gradio-loading,
276
+ #available-files-grid .gradio-loading {
277
+ position: absolute;
278
+ top: 0;
279
+ left: 0;
280
+ right: 0;
281
+ bottom: 0;
282
+ width: 100%;
283
+ height: 100%;
284
+ display: flex;
285
+ align-items: center;
286
+ justify-content: center;
287
+ background: rgba(10, 14, 23, 0.92);
288
+ border-radius: 0.75rem;
289
+ z-index: 999;
290
+ padding: 1.5rem;
291
+ pointer-events: auto;
292
+ }
293
+
294
+ #available-files-grid .gradio-loading > * {
295
+ width: 100%;
296
+ }
297
+
298
+ #available-files-grid .gradio-loading progress,
299
+ #available-files-grid .gradio-loading .progress-bar,
300
+ #available-files-grid .gradio-loading .loading-progress,
301
+ #available-files-grid .gradio-loading [role="progressbar"],
302
+ #available-files-grid .gradio-loading .wrap,
303
+ #available-files-grid .gradio-loading .inner {
304
+ width: 100% !important;
305
+ max-width: none !important;
306
+ }
307
+
308
+ #available-files-grid .gradio-loading .status,
309
+ #available-files-grid .gradio-loading .message,
310
+ #available-files-grid .gradio-loading .label {
311
+ text-align: center;
312
+ }
313
+
314
+ #date-browser-row {
315
+ gap: 0.75rem;
316
+ }
317
+
318
+ #date-browser-row .date-browser-column {
319
+ flex: 1 1 0%;
320
+ min-width: 0;
321
+ }
322
+
323
+ #date-browser-row .date-browser-column > .gradio-dropdown,
324
+ #date-browser-row .date-browser-column > .gradio-button {
325
+ width: 100%;
326
+ }
327
+
328
+ #date-browser-row .date-browser-column > .gradio-dropdown > div {
329
+ width: 100%;
330
+ }
331
+
332
+ #date-browser-row .date-browser-column .gradio-button {
333
+ justify-content: center;
334
+ }
335
+
336
+ #training-files-summary textarea {
337
+ max-height: 12rem;
338
+ overflow-y: auto;
339
+ }
340
+
341
+ #download-selected-button {
342
+ width: 100%;
343
+ position: relative;
344
+ z-index: 0;
345
+ }
346
+
347
+ #download-selected-button .gradio-button {
348
+ width: 100%;
349
+ justify-content: center;
350
+ }
351
+
352
+ #artifact-download-row {
353
+ gap: 0.75rem;
354
+ }
355
+
356
+ #artifact-download-row .artifact-download-button {
357
+ flex: 1 1 0%;
358
+ min-width: 0;
359
+ }
360
+
361
+ #artifact-download-row .artifact-download-button .gradio-button {
362
+ width: 100%;
363
+ justify-content: center;
364
+ }
365
+ """
366
+
367
+
368
+ def _github_cache_key(path: str) -> str:
369
+ return path or "__root__"
370
+
371
+
372
+ def _github_api_url(path: str) -> str:
373
+ clean_path = path.strip("/")
374
+ base = f"https://api.github.com/repos/{TRAINING_DATA_REPO}/contents"
375
+ if clean_path:
376
+ return f"{base}/{clean_path}?ref={TRAINING_DATA_BRANCH}"
377
+ return f"{base}?ref={TRAINING_DATA_BRANCH}"
378
+
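+ # Example: with the default repository and branch, _github_api_url("2024/01")
+ # yields
+ #   https://api.github.com/repos/VincentCroft/ThesisModelData/contents/2024/01?ref=main
+ # (the "2024/01" path segment is purely illustrative).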
379
+
380
+ def list_remote_directory(
381
+ path: str = "", *, force_refresh: bool = False
382
+ ) -> List[Dict[str, Any]]:
383
+ key = _github_cache_key(path)
384
+ if not force_refresh and key in GITHUB_CONTENT_CACHE:
385
+ return GITHUB_CONTENT_CACHE[key]
386
+
387
+ url = _github_api_url(path)
388
+ response = requests.get(url, timeout=30)
389
+ if response.status_code != 200:
390
+ raise RuntimeError(
391
+ f"GitHub API request failed for `{path or '.'}` (status {response.status_code})."
392
+ )
393
+
394
+ payload = response.json()
395
+ if not isinstance(payload, list):
396
+ raise RuntimeError(
397
+ "Unexpected GitHub API payload. Expected a directory listing."
398
+ )
399
+
400
+ GITHUB_CONTENT_CACHE[key] = payload
401
+ return payload
402
+
403
+
404
+ def list_remote_years(force_refresh: bool = False) -> List[str]:
405
+ entries = list_remote_directory("", force_refresh=force_refresh)
406
+ years = [item["name"] for item in entries if item.get("type") == "dir"]
407
+ return sorted(years)
408
+
409
+
410
+ def list_remote_months(year: str, *, force_refresh: bool = False) -> List[str]:
411
+ if not year:
412
+ return []
413
+ entries = list_remote_directory(year, force_refresh=force_refresh)
414
+ months = [item["name"] for item in entries if item.get("type") == "dir"]
415
+ return sorted(months)
416
+
417
+
418
+ def list_remote_days(
419
+ year: str, month: str, *, force_refresh: bool = False
420
+ ) -> List[str]:
421
+ if not year or not month:
422
+ return []
423
+ entries = list_remote_directory(f"{year}/{month}", force_refresh=force_refresh)
424
+ days = [item["name"] for item in entries if item.get("type") == "dir"]
425
+ return sorted(days)
426
+
427
+
428
+ def list_remote_files(
429
+ year: str, month: str, day: str, *, force_refresh: bool = False
430
+ ) -> List[str]:
431
+ if not year or not month or not day:
432
+ return []
433
+ entries = list_remote_directory(
434
+ f"{year}/{month}/{day}", force_refresh=force_refresh
435
+ )
436
+ files = [item["name"] for item in entries if item.get("type") == "file"]
437
+ return sorted(files)
438
+
439
+
440
+ def download_repository_file(year: str, month: str, day: str, filename: str) -> Path:
441
+ if not filename:
442
+ raise ValueError("Filename cannot be empty when downloading repository data.")
443
+
444
+ relative_parts = [part for part in (year, month, day, filename) if part]
445
+ if len(relative_parts) < 4:
446
+ raise ValueError("Provide year, month, day, and filename to download a CSV.")
447
+
448
+ relative_path = "/".join(relative_parts)
449
+ raw_url = (
450
+ f"https://raw.githubusercontent.com/{TRAINING_DATA_REPO}/"
451
+ f"{TRAINING_DATA_BRANCH}/{relative_path}"
452
+ )
453
+
454
+ response = requests.get(raw_url, stream=True, timeout=120)
455
+ if response.status_code != 200:
456
+ raise RuntimeError(
457
+ f"Failed to download `{relative_path}` (status {response.status_code})."
458
+ )
459
+
460
+ target_dir = TRAINING_DATA_DIR.joinpath(year, month, day)
461
+ target_dir.mkdir(parents=True, exist_ok=True)
462
+ target_path = target_dir / filename
463
+
464
+ with open(target_path, "wb") as handle:
465
+ for chunk in response.iter_content(chunk_size=1 << 20):
466
+ if chunk:
467
+ handle.write(chunk)
468
+
469
+ return target_path
470
+
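+ # Usage sketch (hypothetical date and filename):
+ #   path = download_repository_file("2024", "01", "15", "measurements.csv")
+ # streams the raw file from raw.githubusercontent.com and caches it at
+ # training_dataset/2024/01/15/measurements.csv (under TRAINING_DATA_DIR).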
471
+
472
+ def _normalise_header(name: str) -> str:
473
+ return str(name).strip().lower()
474
+
475
+
476
+ def guess_label_from_columns(
477
+ columns: Sequence[str], preferred: Optional[str] = None
478
+ ) -> Optional[str]:
479
+ if not columns:
480
+ return preferred
481
+
482
+ lookup = {_normalise_header(col): str(col) for col in columns}
483
+
484
+ if preferred:
485
+ preferred_stripped = preferred.strip()
486
+ for col in columns:
487
+ if str(col).strip() == preferred_stripped:
488
+ return str(col)
489
+ preferred_norm = _normalise_header(preferred)
490
+ if preferred_norm in lookup:
491
+ return lookup[preferred_norm]
492
+
493
+ for guess in TRAINING_LABEL_GUESSES:
494
+ guess_norm = _normalise_header(guess)
495
+ if guess_norm in lookup:
496
+ return lookup[guess_norm]
497
+
498
+ for col in columns:
499
+ if _normalise_header(col).startswith("fault"):
500
+ return str(col)
501
+
502
+ return str(columns[0])
503
+
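+ # Example: for columns ["Timestamp", "FREQ", "fault_type"] with no preferred
+ # label and no hit among the imported TRAINING_LABEL_GUESSES, the
+ # startswith("fault") rule picks "fault_type"; only if that also fails does
+ # the helper fall back to the first column.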
504
+
505
+ def summarise_training_files(paths: Sequence[str], notes: Sequence[str]) -> str:
506
+ lines = [Path(path).name for path in paths]
507
+ lines.extend(notes)
508
+ return "\n".join(lines) if lines else "No training files available."
509
+
510
+
511
+ def read_training_status(status_file_path: str) -> str:
512
+ """Read the current training status from file."""
513
+ try:
514
+ if Path(status_file_path).exists():
515
+ with open(status_file_path, "r") as f:
516
+ return f.read().strip()
517
+ except Exception:
518
+ pass
519
+ return "Training status unavailable"
520
+
521
+
522
+ def _persist_uploaded_file(file_obj) -> Optional[Path]:
523
+ if file_obj is None:
524
+ return None
525
+
526
+ if isinstance(file_obj, (str, Path)):
527
+ source = Path(file_obj)
528
+ original_name = source.name
+ else:
+ raw_path = getattr(file_obj, "name", "") or getattr(file_obj, "path", "")
+ if not raw_path:
+ # Path("") normalises to "." and is truthy, so check the raw string first.
+ return None
+ source = Path(raw_path)
+ original_name = getattr(file_obj, "orig_name", source.name) or source.name
+ if not source.exists():
+ return None
534
+
535
+ original_name = Path(original_name).name or source.name
536
+
537
+ base_path = Path(original_name)
538
+ destination = TRAINING_UPLOAD_DIR / base_path.name
539
+ counter = 1
540
+ while destination.exists():
541
+ suffix = base_path.suffix or ".csv"
542
+ destination = TRAINING_UPLOAD_DIR / f"{base_path.stem}_{counter}{suffix}"
543
+ counter += 1
544
+
545
+ shutil.copy2(source, destination)
546
+ return destination
547
+
548
+
549
+ def prepare_training_paths(
550
+ paths: Sequence[str], current_label: str, cleanup_missing: bool = False
551
+ ):
552
+ valid_paths: List[str] = []
553
+ notes: List[str] = []
554
+ columns_map: Dict[str, str] = {}
555
+ for path in paths:
556
+ try:
557
+ df = load_measurement_csv(path)
558
+ except Exception as exc: # pragma: no cover - user file diagnostics
559
+ notes.append(f"⚠️ Skipped {Path(path).name}: {exc}")
560
+ if cleanup_missing:
561
+ try:
562
+ Path(path).unlink(missing_ok=True)
563
+ except Exception:
564
+ pass
565
+ continue
566
+ valid_paths.append(str(path))
567
+ for col in df.columns:
568
+ columns_map[_normalise_header(col)] = str(col)
569
+
570
+ summary = summarise_training_files(valid_paths, notes)
571
+ preferred = current_label or LABEL_COLUMN
572
+ dropdown_choices = (
573
+ sorted(columns_map.values()) if columns_map else [preferred or LABEL_COLUMN]
574
+ )
575
+ guessed = guess_label_from_columns(dropdown_choices, preferred)
576
+ dropdown_value = guessed or preferred or LABEL_COLUMN
577
+
578
+ return (
579
+ valid_paths,
580
+ summary,
581
+ gr.update(choices=dropdown_choices, value=dropdown_value),
582
+ )
583
+
584
+
585
+ def append_training_files(new_files, existing_paths: Sequence[str], current_label: str):
586
+ if isinstance(existing_paths, (str, Path)):
587
+ paths: List[str] = [str(existing_paths)]
588
+ elif existing_paths is None:
589
+ paths = []
590
+ else:
591
+ paths = list(existing_paths)
592
+ if new_files:
593
+ for file in new_files:
594
+ persisted = _persist_uploaded_file(file)
595
+ if persisted is None:
596
+ continue
597
+ path_str = str(persisted)
598
+ if path_str not in paths:
599
+ paths.append(path_str)
600
+
601
+ return prepare_training_paths(paths, current_label, cleanup_missing=True)
602
+
603
+
604
+ def load_repository_training_files(current_label: str, force_refresh: bool = False):
+ # ``force_refresh`` requires no filesystem work: downloads are on-demand
+ # and previously downloaded files are intentionally kept in the cache.
+ # The flag exists only to trigger the downstream UI refresh.
611
+
612
+ csv_paths = sorted(
613
+ str(path) for path in TRAINING_DATA_DIR.rglob("*.csv") if path.is_file()
614
+ )
615
+ if not csv_paths:
616
+ message = (
617
+ "No local database CSVs are available yet. Use the database browser "
618
+ "below to download specific days before training."
619
+ )
620
+ default_label = current_label or LABEL_COLUMN or "Fault"
621
+ return (
622
+ [],
623
+ message,
624
+ gr.update(choices=[default_label], value=default_label),
625
+ message,
626
+ )
627
+
628
+ valid_paths, summary, label_update = prepare_training_paths(
629
+ csv_paths, current_label, cleanup_missing=False
630
+ )
631
+
632
+ info = (
633
+ f"Ready with {len(valid_paths)} CSV file(s) cached locally under "
634
+ f"the database cache `{TRAINING_DATA_DIR}`."
635
+ )
636
+
637
+ return valid_paths, summary, label_update, info
638
+
639
+
640
+ def refresh_remote_browser(force_refresh: bool = False):
641
+ if force_refresh:
642
+ GITHUB_CONTENT_CACHE.clear()
643
+ try:
644
+ years = list_remote_years(force_refresh=force_refresh)
645
+ if years:
646
+ message = "Select a year, month, and day to list available CSV files."
647
+ else:
648
+ message = (
649
+ "⚠️ No directories were found in the database root. Verify the upstream "
650
+ "structure."
651
+ )
652
+ return (
653
+ gr.update(choices=years, value=None),
654
+ gr.update(choices=[], value=None),
655
+ gr.update(choices=[], value=None),
656
+ gr.update(choices=[], value=[]),
657
+ message,
658
+ )
659
+ except Exception as exc:
660
+ return (
661
+ gr.update(choices=[], value=None),
662
+ gr.update(choices=[], value=None),
663
+ gr.update(choices=[], value=None),
664
+ gr.update(choices=[], value=[]),
665
+ f"⚠️ Failed to query database: {exc}",
666
+ )
667
+
668
+
669
+ def on_year_change(year: Optional[str]):
670
+ if not year:
671
+ return (
672
+ gr.update(choices=[], value=None),
673
+ gr.update(choices=[], value=None),
674
+ gr.update(choices=[], value=[]),
675
+ "Select a year to continue.",
676
+ )
677
+ try:
678
+ months = list_remote_months(year)
679
+ message = (
680
+ f"Year `{year}` selected. Choose a month to drill down."
681
+ if months
682
+ else f"⚠️ No months available under `{year}`."
683
+ )
684
+ return (
685
+ gr.update(choices=months, value=None),
686
+ gr.update(choices=[], value=None),
687
+ gr.update(choices=[], value=[]),
688
+ message,
689
+ )
690
+ except Exception as exc:
691
+ return (
692
+ gr.update(choices=[], value=None),
693
+ gr.update(choices=[], value=None),
694
+ gr.update(choices=[], value=[]),
695
+ f"⚠️ Failed to list months: {exc}",
696
+ )
697
+
698
+
699
+ def on_month_change(year: Optional[str], month: Optional[str]):
700
+ if not year or not month:
701
+ return (
702
+ gr.update(choices=[], value=None),
703
+ gr.update(choices=[], value=[]),
704
+ "Select a month to continue.",
705
+ )
706
+ try:
707
+ days = list_remote_days(year, month)
708
+ message = (
709
+ f"Month `{year}/{month}` ready. Pick a day to view files."
710
+ if days
711
+ else f"⚠️ No day folders found under `{year}/{month}`."
712
+ )
713
+ return (
714
+ gr.update(choices=days, value=None),
715
+ gr.update(choices=[], value=[]),
716
+ message,
717
+ )
718
+ except Exception as exc:
719
+ return (
720
+ gr.update(choices=[], value=None),
721
+ gr.update(choices=[], value=[]),
722
+ f"⚠️ Failed to list days: {exc}",
723
+ )
724
+
725
+
726
+ def on_day_change(year: Optional[str], month: Optional[str], day: Optional[str]):
727
+ if not year or not month or not day:
728
+ return (
729
+ gr.update(choices=[], value=[]),
730
+ "Select a day to load file names.",
731
+ )
732
+ try:
733
+ files = list_remote_files(year, month, day)
734
+ message = (
735
+ f"{len(files)} file(s) available for `{year}/{month}/{day}`."
736
+ if files
737
+ else f"⚠️ No CSV files found under `{year}/{month}/{day}`."
738
+ )
739
+ return (
740
+ gr.update(choices=files, value=[]),
741
+ message,
742
+ )
743
+ except Exception as exc:
744
+ return (
745
+ gr.update(choices=[], value=[]),
746
+ f"⚠️ Failed to list files: {exc}",
747
+ )
748
+
749
+
750
+ def download_selected_files(
751
+ year: Optional[str],
752
+ month: Optional[str],
753
+ day: Optional[str],
754
+ filenames: Sequence[str],
755
+ current_label: str,
756
+ ):
757
+ if not filenames:
758
+ message = "Select at least one CSV before downloading."
759
+ local = load_repository_training_files(current_label)
760
+ return (*local, gr.update(), message)
761
+
762
+ success: List[str] = []
763
+ notes: List[str] = []
764
+ for filename in filenames:
765
+ try:
766
+ path = download_repository_file(
767
+ year or "", month or "", day or "", filename
768
+ )
769
+ success.append(str(path))
770
+ except Exception as exc:
771
+ notes.append(f"⚠️ {filename}: {exc}")
772
+
773
+ local = load_repository_training_files(current_label)
774
+
775
+ message_lines = []
776
+ if success:
777
+ message_lines.append(
778
+ f"Downloaded {len(success)} file(s) to the database cache `{TRAINING_DATA_DIR}`."
779
+ )
780
+ if notes:
781
+ message_lines.extend(notes)
782
+ if not message_lines:
783
+ message_lines.append("No files were downloaded.")
784
+
785
+ return (*local, gr.update(value=[]), "\n".join(message_lines))
786
+
787
+
788
+ def download_day_bundle(
789
+ year: Optional[str],
790
+ month: Optional[str],
791
+ day: Optional[str],
792
+ current_label: str,
793
+ ):
794
+ if not (year and month and day):
795
+ local = load_repository_training_files(current_label)
796
+ return (
797
+ *local,
798
+ gr.update(),
799
+ "Select a year, month, and day before downloading an entire day.",
800
+ )
801
+
802
+ try:
803
+ files = list_remote_files(year, month, day)
804
+ except Exception as exc:
805
+ local = load_repository_training_files(current_label)
806
+ return (
807
+ *local,
808
+ gr.update(),
809
+ f"⚠️ Failed to list CSVs for `{year}/{month}/{day}`: {exc}",
810
+ )
811
+
812
+ if not files:
813
+ local = load_repository_training_files(current_label)
814
+ return (
815
+ *local,
816
+ gr.update(),
817
+ f"No CSV files were found for `{year}/{month}/{day}`.",
818
+ )
819
+
820
+ result = list(download_selected_files(year, month, day, files, current_label))
821
+ result[-1] = (
822
+ f"Downloaded all {len(files)} CSV file(s) for `{year}/{month}/{day}`.\n"
823
+ f"{result[-1]}"
824
+ )
825
+ return tuple(result)
826
+
827
+
828
+ def download_month_bundle(
829
+ year: Optional[str], month: Optional[str], current_label: str
830
+ ):
831
+ if not (year and month):
832
+ local = load_repository_training_files(current_label)
833
+ return (
834
+ *local,
835
+ gr.update(),
836
+ "Select a year and month before downloading an entire month.",
837
+ )
838
+
839
+ try:
840
+ days = list_remote_days(year, month)
841
+ except Exception as exc:
842
+ local = load_repository_training_files(current_label)
843
+ return (
844
+ *local,
845
+ gr.update(),
846
+ f"⚠️ Failed to enumerate days for `{year}/{month}`: {exc}",
847
+ )
848
+
849
+ if not days:
850
+ local = load_repository_training_files(current_label)
851
+ return (
852
+ *local,
853
+ gr.update(),
854
+ f"No day folders were found for `{year}/{month}`.",
855
+ )
856
+
857
+ downloaded = 0
858
+ notes: List[str] = []
859
+ for day in days:
860
+ try:
861
+ files = list_remote_files(year, month, day)
862
+ except Exception as exc:
863
+ notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
864
+ continue
865
+ if not files:
866
+ notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
867
+ continue
868
+ for filename in files:
869
+ try:
870
+ download_repository_file(year, month, day, filename)
871
+ downloaded += 1
872
+ except Exception as exc:
873
+ notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
874
+
875
+ local = load_repository_training_files(current_label)
876
+ message_lines = []
877
+ if downloaded:
878
+ message_lines.append(
879
+ f"Downloaded {downloaded} CSV file(s) for `{year}/{month}` into the "
880
+ f"database cache `{TRAINING_DATA_DIR}`."
881
+ )
882
+ message_lines.extend(notes)
883
+ if not message_lines:
884
+ message_lines.append("No files were downloaded.")
885
+
886
+ return (*local, gr.update(value=[]), "\n".join(message_lines))
887
+
888
+
889
+ def download_year_bundle(year: Optional[str], current_label: str):
890
+ if not year:
891
+ local = load_repository_training_files(current_label)
892
+ return (
893
+ *local,
894
+ gr.update(),
895
+ "Select a year before downloading an entire year of CSVs.",
896
+ )
897
+
898
+ try:
899
+ months = list_remote_months(year)
900
+ except Exception as exc:
901
+ local = load_repository_training_files(current_label)
902
+ return (
903
+ *local,
904
+ gr.update(),
905
+ f"⚠️ Failed to enumerate months for `{year}`: {exc}",
906
+ )
907
+
908
+ if not months:
909
+ local = load_repository_training_files(current_label)
910
+ return (
911
+ *local,
912
+ gr.update(),
913
+ f"No month folders were found for `{year}`.",
914
+ )
915
+
916
+ downloaded = 0
917
+ notes: List[str] = []
918
+ for month in months:
919
+ try:
920
+ days = list_remote_days(year, month)
921
+ except Exception as exc:
922
+ notes.append(f"⚠️ Failed to list `{year}/{month}`: {exc}")
923
+ continue
924
+ if not days:
925
+ notes.append(f"⚠️ No day folders in `{year}/{month}`.")
926
+ continue
927
+ for day in days:
928
+ try:
929
+ files = list_remote_files(year, month, day)
930
+ except Exception as exc:
931
+ notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
932
+ continue
933
+ if not files:
934
+ notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
935
+ continue
936
+ for filename in files:
937
+ try:
938
+ download_repository_file(year, month, day, filename)
939
+ downloaded += 1
940
+ except Exception as exc:
941
+ notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
942
+
943
+ local = load_repository_training_files(current_label)
944
+ message_lines = []
945
+ if downloaded:
946
+ message_lines.append(
947
+ f"Downloaded {downloaded} CSV file(s) for `{year}` into the "
948
+ f"database cache `{TRAINING_DATA_DIR}`."
949
+ )
950
+ message_lines.extend(notes)
951
+ if not message_lines:
952
+ message_lines.append("No files were downloaded.")
953
+
954
+ return (*local, gr.update(value=[]), "\n".join(message_lines))
955
+
956
+
957
+ def clear_downloaded_cache(current_label: str):
958
+ status_message = ""
959
+ try:
960
+ if TRAINING_DATA_DIR.exists():
961
+ shutil.rmtree(TRAINING_DATA_DIR)
962
+ TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
963
+ status_message = (
964
+ f"Cleared all downloaded CSVs from database cache `{TRAINING_DATA_DIR}`."
965
+ )
966
+ except Exception as exc:
967
+ status_message = f"⚠️ Failed to clear database cache: {exc}"
968
+
969
+ local = load_repository_training_files(current_label, force_refresh=True)
970
+ remote = list(refresh_remote_browser(force_refresh=False))
971
+ if status_message:
972
+ previous = remote[-1]
973
+ if isinstance(previous, str) and previous:
974
+ remote[-1] = f"{status_message}\n{previous}"
975
+ else:
976
+ remote[-1] = status_message
977
+
978
+ return (*local, *remote)
979
+
980
+
981
+ def normalise_output_directory(directory: Optional[str]) -> Path:
982
+ base = Path(directory or MODEL_OUTPUT_DIR)
983
+ base = base.expanduser()
984
+ if not base.is_absolute():
985
+ base = (Path.cwd() / base).resolve()
986
+ return base
987
+
988
+
989
+ def resolve_output_path(
990
+ directory: Optional[Union[Path, str]], filename: Optional[str], fallback: str
991
+ ) -> Path:
992
+ if isinstance(directory, Path):
993
+ base = directory
994
+ else:
995
+ base = normalise_output_directory(directory)
+ if filename:
+ # Guard on the raw string: Path("") normalises to "." (which is truthy),
+ # so an empty filename must fall through to the fallback below.
+ candidate = Path(filename).expanduser()
+ if candidate.is_absolute():
+ return candidate
+ return (base / candidate).resolve()
+ return (base / fallback).resolve()
1002
+
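+ # Resolution sketch, assuming MODEL_OUTPUT_DIR resolves to /app/model:
+ #   resolve_output_path(None, None, "pmu_metadata.json")  -> /app/model/pmu_metadata.json
+ #   resolve_output_path("out", "model.keras", "x")        -> <cwd>/out/model.keras
+ #   resolve_output_path(None, "/tmp/model.keras", "x")    -> /tmp/model.keras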
1003
+
1004
+ ARTIFACT_FILE_EXTENSIONS: Tuple[str, ...] = (
1005
+ ".keras",
1006
+ ".h5",
1007
+ ".joblib",
1008
+ ".pkl",
1009
+ ".json",
1010
+ ".onnx",
1011
+ ".zip",
1012
+ ".txt",
1013
+ )
1014
+
1015
+
1016
+ def gather_directory_choices(current: Optional[str]) -> Tuple[List[str], str]:
1017
+ base = normalise_output_directory(current or str(MODEL_OUTPUT_DIR))
1018
+ candidates = {str(base)}
1019
+ try:
1020
+ for candidate in base.parent.iterdir():
1021
+ if candidate.is_dir():
1022
+ candidates.add(str(candidate.resolve()))
1023
+ except Exception:
1024
+ pass
1025
+ return sorted(candidates), str(base)
1026
+
1027
+
1028
+ def gather_artifact_choices(
1029
+ directory: Optional[str], selection: Optional[str] = None
1030
+ ) -> Tuple[List[Tuple[str, str]], Optional[str]]:
1031
+ base = normalise_output_directory(directory)
1032
+ choices: List[Tuple[str, str]] = []
1033
+ selected_value: Optional[str] = None
1034
+ if base.exists():
1035
+ try:
1036
+ artifacts = sorted(
1037
+ [
1038
+ path
1039
+ for path in base.iterdir()
1040
+ if path.is_file()
1041
+ and (
1042
+ not ARTIFACT_FILE_EXTENSIONS
1043
+ or path.suffix.lower() in ARTIFACT_FILE_EXTENSIONS
1044
+ )
1045
+ ],
1046
+ key=lambda path: path.name.lower(),
1047
+ )
1048
+ choices = [(artifact.name, str(artifact)) for artifact in artifacts]
1049
+ except Exception:
1050
+ choices = []
1051
+
1052
+ if selection and any(value == selection for _, value in choices):
1053
+ selected_value = selection
1054
+ elif choices:
1055
+ selected_value = choices[0][1]
1056
+
1057
+ return choices, selected_value
1058
+
1059
+
1060
+ def download_button_state(path: Optional[Union[str, Path]]):
1061
+ if not path:
1062
+ return gr.update(value=None, visible=False)
1063
+ candidate = Path(path)
1064
+ if candidate.exists():
1065
+ return gr.update(value=str(candidate), visible=True)
1066
+ return gr.update(value=None, visible=False)
1067
+
1068
+
1069
+ def clear_training_files():
1070
+ default_label = LABEL_COLUMN or "Fault"
1071
+ for cached_file in TRAINING_UPLOAD_DIR.glob("*"):
1072
+ try:
1073
+ if cached_file.is_file():
1074
+ cached_file.unlink(missing_ok=True)
1075
+ except Exception:
1076
+ pass
1077
+ return (
1078
+ [],
1079
+ "No training files selected.",
1080
+ gr.update(choices=[default_label], value=default_label),
1081
+ gr.update(value=None),
1082
+ )
1083
+
1084
+
1085
+ PROJECT_OVERVIEW_MD = """
1086
+ ## Project Overview
1087
+
1088
+ This project focuses on classifying faults in electrical transmission lines and
1089
+ grid-connected photovoltaic (PV) systems by comparing a classical machine
+ learning baseline (SVM) with deep neural architectures (CNN-LSTM and TCN).
1091
+
1092
+ ## Datasets
1093
+
1094
+ ### Transmission Line Fault Dataset
1095
+ - 134,406 samples collected from Phasor Measurement Units (PMUs)
1096
+ - 14 monitored channels covering currents, voltages, magnitudes, frequency, and phase angles
1097
+ - Labels span symmetrical and asymmetrical faults: NF, L-G, LL, LL-G, LLL, and LLL-G
1098
+ - Time span: 0 to 5.7 seconds with high-frequency sampling
1099
+
1100
+ ### Grid-Connected PV System Fault Dataset
1101
+ - 2,163,480 samples from 16 experimental scenarios
1102
+ - 14 features including PV array measurements (Ipv, Vpv, Vdc), three-phase currents/voltages, aggregate magnitudes (Iabc, Vabc), and frequency indicators (If, Vf)
1103
+ - Captures array, inverter, grid anomaly, feedback sensor, and MPPT controller faults at 9.9989 μs sampling intervals
1104
+
1105
+ ## Data Format Quick Reference
1106
+
1107
+ Each measurement file may be comma or tab separated and typically exposes the
1108
+ following ordered columns:
1109
+
1110
+ 1. `Timestamp`
1111
+ 2. `[325] UPMU_SUB22:FREQ` – system frequency (Hz)
1112
+ 3. `[326] UPMU_SUB22:DFDT` – frequency rate-of-change
1113
+ 4. `[327] UPMU_SUB22:FLAG` – PMU status flag
1114
+ 5. `[328] UPMU_SUB22-L1:MAG` – phase A voltage magnitude
1115
+ 6. `[329] UPMU_SUB22-L1:ANG` – phase A voltage angle
1116
+ 7. `[330] UPMU_SUB22-L2:MAG` – phase B voltage magnitude
1117
+ 8. `[331] UPMU_SUB22-L2:ANG` – phase B voltage angle
1118
+ 9. `[332] UPMU_SUB22-L3:MAG` – phase C voltage magnitude
1119
+ 10. `[333] UPMU_SUB22-L3:ANG` – phase C voltage angle
1120
+ 11. `[334] UPMU_SUB22-C1:MAG` – phase A current magnitude
1121
+ 12. `[335] UPMU_SUB22-C1:ANG` – phase A current angle
1122
+ 13. `[336] UPMU_SUB22-C2:MAG` – phase B current magnitude
1123
+ 14. `[337] UPMU_SUB22-C2:ANG` – phase B current angle
1124
+ 15. `[338] UPMU_SUB22-C3:MAG` – phase C current magnitude
1125
+ 16. `[339] UPMU_SUB22-C3:ANG` – phase C current angle
1126
+
1127
+ The training tab automatically downloads the latest CSV exports from the
1128
+ `VincentCroft/ThesisModelData` repository and concatenates them before building
1129
+ sliding windows.
1130
+
1131
+ ## Models Developed
1132
+
1133
+ 1. **Support Vector Machine (SVM)** – provides the classical machine learning baseline with balanced accuracy across both datasets (85% PMU / 83% PV).
1134
+ 2. **CNN-LSTM** – couples convolutional feature extraction with temporal memory, achieving 92% PMU / 89% PV accuracy.
1135
+ 3. **Temporal Convolutional Network (TCN)** – leverages dilated convolutions for long-range context and delivers the best trade-off between accuracy and training time (94% PMU / 91% PV).
1136
+
1137
+ ## Results Summary
1138
+
1139
+ - **Transmission Line Fault Classification**: SVM 85%, CNN-LSTM 92%, TCN 94%
1140
+ - **PV System Fault Classification**: SVM 83%, CNN-LSTM 89%, TCN 91%
1141
+
1142
+ Use the **Inference** tab to score new PMU/PV windows and the **Training** tab to
1143
+ fine-tune or retrain any of the supported models directly within Hugging Face
1144
+ Spaces. The logs panel will surface TensorBoard archives whenever deep-learning
1145
+ models are trained.
1146
+ """
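+ # Back-of-envelope check against the overview above: with the default window
+ # settings (sequence length 32, stride 4), the 134,406-row transmission-line
+ # dataset yields (134406 - 32) // 4 + 1 = 33,594 sliding windows per pass.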
1147
+
1148
+
1149
+ def load_measurement_csv(path: str) -> pd.DataFrame:
1150
+ """Read a PMU/PV measurement file with flexible separators and column mapping."""
1151
+
1152
+ try:
1153
+ df = pd.read_csv(path, sep=None, engine="python", encoding="utf-8-sig")
1154
+ except Exception:
1155
+ df = None
1156
+ for separator in ("\t", ",", ";"):
1157
+ try:
1158
+ df = pd.read_csv(
1159
+ path, sep=separator, engine="python", encoding="utf-8-sig"
1160
+ )
1161
+ break
1162
+ except Exception:
1163
+ df = None
1164
+ if df is None:
1165
+ raise
1166
+
1167
+ # Clean column names
1168
+ df.columns = [str(col).strip() for col in df.columns]
1169
+
1170
+ print(f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns")
1171
+ print(f"Columns: {list(df.columns)}")
1172
+ print(f"Data shape: {df.shape}")
1173
+
1174
+ # Check if we have enough data for training
1175
+ if len(df) < 100:
1176
+ print(
1177
+ f"Warning: Only {len(df)} rows of data. Recommend at least 1000 rows for effective training."
1178
+ )
1179
+
1180
+ # Check for label column
1181
+ has_label = any(
1182
+ col.lower() in ["fault", "label", "class", "target"] for col in df.columns
1183
+ )
1184
+ if not has_label:
1185
+ print(
1186
+ "Warning: No label column found. Adding dummy 'Fault' column with value 'Normal' for all samples."
1187
+ )
1188
+ df["Fault"] = "Normal" # Add dummy label for training
1189
+
1190
+ # Create column mapping - map similar column names to expected format
1191
+ column_mapping = {}
1192
+ expected_cols = list(REQUIRED_PMU_COLUMNS)
1193
+
1194
+ # If we have at least the right number of numeric columns after Timestamp, use positional mapping
1195
+ if "Timestamp" in df.columns:
1196
+ numeric_cols = [col for col in df.columns if col != "Timestamp"]
1197
+ if len(numeric_cols) >= len(expected_cols):
1198
+ # Map by position (after Timestamp)
1199
+ for i, expected_col in enumerate(expected_cols):
1200
+ if i < len(numeric_cols):
1201
+ column_mapping[numeric_cols[i]] = expected_col
1202
+
1203
+ # Rename columns to match expected format
1204
+ df = df.rename(columns=column_mapping)
1205
+
1206
+ # Check if we have the required columns after mapping
1207
+ missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]
1208
+ if missing:
1209
+ # If still missing, try a more flexible approach
1210
+ available_numeric = df.select_dtypes(include=[np.number]).columns.tolist()
1211
+ if len(available_numeric) >= len(expected_cols):
1212
+ # Use the first N numeric columns
1213
+ for i, expected_col in enumerate(expected_cols):
1214
+ if i < len(available_numeric):
1215
+ if available_numeric[i] not in df.columns:
1216
+ continue
1217
+ df = df.rename(columns={available_numeric[i]: expected_col})
1218
+
1219
+ # Recheck missing columns
1220
+ missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]
1221
+
1222
+ if missing:
1223
+ missing_str = ", ".join(missing)
1224
+ available_str = ", ".join(df.columns.tolist())
1225
+ raise ValueError(
1226
+ f"Missing required PMU feature columns: {missing_str}. "
1227
+ f"Available columns: {available_str}. "
1228
+ "Please ensure your CSV has the correct format with Timestamp followed by PMU measurements."
1229
+ )
1230
+
1231
+ return df
1232
+
1233
+
1234
+ def apply_metadata(metadata: Dict[str, Any]) -> None:
1235
+ global FEATURE_COLUMNS, LABEL_CLASSES, LABEL_COLUMN, SEQUENCE_LENGTH, DEFAULT_WINDOW_STRIDE, MODEL_TYPE, MODEL_FORMAT
1236
+ FEATURE_COLUMNS = [
1237
+ str(col) for col in metadata.get("feature_columns", DEFAULT_FEATURE_COLUMNS)
1238
+ ]
1239
+ LABEL_CLASSES = [str(label) for label in metadata.get("label_classes", [])]
1240
+ LABEL_COLUMN = str(metadata.get("label_column", "Fault"))
1241
+ SEQUENCE_LENGTH = int(metadata.get("sequence_length", DEFAULT_SEQUENCE_LENGTH))
1242
+ DEFAULT_WINDOW_STRIDE = int(metadata.get("stride", DEFAULT_STRIDE))
1243
+ MODEL_TYPE = str(metadata.get("model_type", "cnn_lstm")).lower()
1244
+ MODEL_FORMAT = str(
1245
+ metadata.get("model_format", "joblib" if MODEL_TYPE == "svm" else "keras")
1246
+ ).lower()
1247
+
1248
+
1249
+ apply_metadata(METADATA)
1250
+
1251
+
1252
+ def sync_label_classes_from_model(model: Optional[object]) -> None:
1253
+ global LABEL_CLASSES
1254
+ if model is None:
1255
+ return
1256
+ if hasattr(model, "classes_"):
1257
+ LABEL_CLASSES = [str(label) for label in getattr(model, "classes_")]
1258
+ elif not LABEL_CLASSES and hasattr(model, "output_shape"):
1259
+ LABEL_CLASSES = [str(i) for i in range(int(model.output_shape[-1]))]
1260
+
1261
+
1262
+ # Load model and scaler with error handling
1263
+ print("Loading model and scaler...")
1264
+ try:
1265
+ MODEL = try_load_model(MODEL_PATH, MODEL_TYPE, MODEL_FORMAT)
1266
+ print(f"Model loaded: {MODEL is not None}")
1267
+ except Exception as e:
1268
+ print(f"Model loading failed: {e}")
1269
+ MODEL = None
1270
+
1271
+ try:
1272
+ SCALER = try_load_scaler(SCALER_PATH)
1273
+ print(f"Scaler loaded: {SCALER is not None}")
1274
+ except Exception as e:
1275
+ print(f"Scaler loading failed: {e}")
1276
+ SCALER = None
1277
+
1278
+ try:
1279
+ sync_label_classes_from_model(MODEL)
1280
+ print("Label classes synchronized")
1281
+ except Exception as e:
1282
+ print(f"Label sync failed: {e}")
1283
+
1284
+ print("Application initialization completed.")
1285
+ print(
1286
+ f"Ready to start Gradio interface. Model available: {MODEL is not None}, Scaler available: {SCALER is not None}"
1287
+ )
1288
+
1289
+
1290
+ def refresh_artifacts(model_path: Path, scaler_path: Path, metadata_path: Path) -> None:
1291
+ global MODEL_PATH, SCALER_PATH, METADATA_PATH, MODEL, SCALER, METADATA
1292
+ MODEL_PATH = model_path
1293
+ SCALER_PATH = scaler_path
1294
+ METADATA_PATH = metadata_path
1295
+ METADATA = load_metadata(metadata_path)
1296
+ apply_metadata(METADATA)
1297
+ MODEL = try_load_model(model_path, MODEL_TYPE, MODEL_FORMAT)
1298
+ SCALER = try_load_scaler(scaler_path)
1299
+ sync_label_classes_from_model(MODEL)
1300
+
1301
+
1302
+ # --------------------------------------------------------------------------------------
1303
+ # Pre-processing helpers
1304
+ # --------------------------------------------------------------------------------------
1305
+
1306
+
1307
+ def ensure_ready():
1308
+ if MODEL is None or SCALER is None:
1309
+ raise RuntimeError(
1310
+ "The model and feature scaler are not available. Upload the trained model "
1311
+ "(for example `pmu_cnn_lstm_model.keras`, `pmu_tcn_model.keras`, or `pmu_svm_model.joblib`), "
1312
+ "the feature scaler (`pmu_feature_scaler.pkl`), and the metadata JSON (`pmu_metadata.json`) to the Space root "
1313
+ "or configure the Hugging Face Hub environment variables so the artifacts can be downloaded "
1314
+ "automatically."
1315
+ )
1316
+
1317
+
1318
+ def parse_text_features(text: str) -> np.ndarray:
1319
+ cleaned = re.sub(r"[;\n\t]+", ",", text.strip())
1320
+ arr = np.fromstring(cleaned, sep=",")
1321
+ if arr.size == 0:
1322
+ raise ValueError(
1323
+ "No feature values were parsed. Please enter comma-separated numbers."
1324
+ )
1325
+ return arr.astype(np.float32)
1326
+
1327
+
1328
+ def apply_scaler(sequences: np.ndarray) -> np.ndarray:
1329
+ if SCALER is None:
1330
+ return sequences
1331
+ shape = sequences.shape
1332
+ flattened = sequences.reshape(-1, shape[-1])
1333
+ scaled = SCALER.transform(flattened)
1334
+ return scaled.reshape(shape)
1335
+
1336
+
1337
+ def make_sliding_windows(
1338
+ data: np.ndarray, sequence_length: int, stride: int
1339
+ ) -> np.ndarray:
1340
+ if data.shape[0] < sequence_length:
1341
+ raise ValueError(
1342
+ f"The dataset contains {data.shape[0]} rows which is less than the requested sequence "
1343
+ f"length {sequence_length}. Provide more samples or reduce the sequence length."
1344
+ )
1345
+ windows = [
1346
+ data[start : start + sequence_length]
1347
+ for start in range(0, data.shape[0] - sequence_length + 1, stride)
1348
+ ]
1349
+ return np.stack(windows)
1350
+
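+ # Shape sketch: for data of shape (100, n_features) with sequence_length=32
+ # and stride=4, the window starts are 0, 4, ..., 68 (18 windows), so the
+ # result has shape (18, 32, n_features).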
1351
+
1352
+ def dataframe_to_sequences(
1353
+ df: pd.DataFrame,
1354
+ *,
1355
+ sequence_length: int,
1356
+ stride: int,
1357
+ feature_columns: Sequence[str],
1358
+ drop_label: bool = True,
1359
+ ) -> np.ndarray:
1360
+ work_df = df.copy()
1361
+ if drop_label and LABEL_COLUMN in work_df.columns:
1362
+ work_df = work_df.drop(columns=[LABEL_COLUMN])
1363
+ if "Timestamp" in work_df.columns:
1364
+ work_df = work_df.sort_values("Timestamp")
1365
+
1366
+ available_cols = [c for c in feature_columns if c in work_df.columns]
1367
+ n_features = len(feature_columns)
1368
+ if available_cols and len(available_cols) == n_features:
1369
+ array = work_df[available_cols].astype(np.float32).to_numpy()
1370
+ return make_sliding_windows(array, sequence_length, stride)
1371
+
1372
+ numeric_df = work_df.select_dtypes(include=[np.number])
1373
+ array = numeric_df.astype(np.float32).to_numpy()
1374
+ if array.shape[1] == n_features * sequence_length:
1375
+ return array.reshape(array.shape[0], sequence_length, n_features)
1376
+ if sequence_length == 1 and array.shape[1] == n_features:
1377
+ return array.reshape(array.shape[0], 1, n_features)
1378
+ raise ValueError(
1379
+ "CSV columns do not match the expected feature layout. Include the full PMU feature set "
1380
+ "or provide pre-shaped sliding window data."
1381
+ )
1382
+
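+ # Usage sketch tying the helpers together (assumes ``df`` came from
+ # load_measurement_csv and carries the full feature set):
+ #
+ #   seqs = dataframe_to_sequences(
+ #       df,
+ #       sequence_length=SEQUENCE_LENGTH,
+ #       stride=DEFAULT_WINDOW_STRIDE,
+ #       feature_columns=FEATURE_COLUMNS,
+ #   )
+ #   # seqs.shape == (num_windows, SEQUENCE_LENGTH, len(FEATURE_COLUMNS))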
1383
+
1384
+ def label_name(index: int) -> str:
1385
+ if 0 <= index < len(LABEL_CLASSES):
1386
+ return str(LABEL_CLASSES[index])
1387
+ return f"class_{index}"
1388
+
1389
+
1390
+ def format_predictions(probabilities: np.ndarray) -> pd.DataFrame:
1391
+ rows: List[Dict[str, object]] = []
1392
+ order = np.argsort(probabilities, axis=1)[:, ::-1]
1393
+ for idx, (prob_row, ranking) in enumerate(zip(probabilities, order)):
1394
+ top_idx = int(ranking[0])
1395
+ top_label = label_name(top_idx)
1396
+ top_conf = float(prob_row[top_idx])
1397
+ top3 = [f"{label_name(i)} ({prob_row[i]*100:.2f}%)" for i in ranking[:3]]
1398
+ rows.append(
1399
+ {
1400
+ "window": idx,
1401
+ "predicted_label": top_label,
1402
+ "confidence": round(top_conf, 4),
1403
+ "top3": " | ".join(top3),
1404
+ }
1405
+ )
1406
+ return pd.DataFrame(rows)
1407
+
1408
+
1409
+ def probabilities_to_json(probabilities: np.ndarray) -> List[Dict[str, object]]:
1410
+ payload: List[Dict[str, object]] = []
1411
+ for idx, prob_row in enumerate(probabilities):
1412
+ payload.append(
1413
+ {
1414
+ "window": int(idx),
1415
+ "probabilities": {
1416
+ label_name(i): float(prob_row[i]) for i in range(prob_row.shape[0])
1417
+ },
1418
+ }
1419
+ )
1420
+ return payload
1421
+
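+ # Example payload for a model trained over the transmission-line labels
+ # (probability values illustrative):
+ #   [{"window": 0, "probabilities": {"NF": 0.91, "L-G": 0.06, "LL": 0.03}}, ...]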
1422
+
1423
+ def predict_sequences(
1424
+ sequences: np.ndarray,
1425
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
1426
+ ensure_ready()
1427
+ sequences = apply_scaler(sequences.astype(np.float32))
1428
+ if MODEL_TYPE == "svm":
1429
+ flattened = sequences.reshape(sequences.shape[0], -1)
1430
+ if hasattr(MODEL, "predict_proba"):
1431
+ probs = MODEL.predict_proba(flattened)
1432
+ else:
1433
+ raise RuntimeError(
1434
+ "Loaded SVM model does not expose predict_proba. Retrain with probability=True."
1435
+ )
1436
+ else:
1437
+ probs = MODEL.predict(sequences, verbose=0)
1438
+ table = format_predictions(probs)
1439
+ json_probs = probabilities_to_json(probs)
1440
+ architecture = MODEL_TYPE.replace("_", "-").upper()
1441
+ status = f"Generated {len(sequences)} windows. {architecture} model output dimension: {probs.shape[1]}."
1442
+ return status, table, json_probs
1443
+
1444
+
1445
+ def predict_from_text(
1446
+ text: str, sequence_length: int
1447
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
1448
+ arr = parse_text_features(text)
1449
+ n_features = len(FEATURE_COLUMNS)
1450
+ if arr.size % n_features != 0:
1451
+ raise ValueError(
1452
+ f"The number of values ({arr.size}) is not a multiple of the feature dimension "
1453
+ f"({n_features}). Provide values in groups of {n_features}."
1454
+ )
1455
+ timesteps = arr.size // n_features
1456
+ if timesteps != sequence_length:
1457
+ raise ValueError(
1458
+ f"Detected {timesteps} timesteps which does not match the configured sequence length "
1459
+ f"({sequence_length})."
1460
+ )
1461
+ sequences = arr.reshape(1, sequence_length, n_features)
1462
+ status, table, probs = predict_sequences(sequences)
1463
+ status = f"Single window prediction complete. {status}"
1464
+ return status, table, probs
1465
+
1466
+
1467
+ def predict_from_csv(
1468
+ file_obj, sequence_length: int, stride: int
1469
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
1470
+ df = load_measurement_csv(file_obj.name)
1471
+ sequences = dataframe_to_sequences(
1472
+ df,
1473
+ sequence_length=sequence_length,
1474
+ stride=stride,
1475
+ feature_columns=FEATURE_COLUMNS,
1476
+ )
1477
+ status, table, probs = predict_sequences(sequences)
1478
+ status = f"CSV processed successfully. {status}"  # window count is already in ``status``
1479
+ return status, table, probs
1480
+
1481
+
1482
+ # --------------------------------------------------------------------------------------
1483
+ # Training helpers
1484
+ # --------------------------------------------------------------------------------------
1485
+
1486
+
1487
+ def classification_report_to_dataframe(report: Dict[str, Any]) -> pd.DataFrame:
1488
+ rows: List[Dict[str, Any]] = []
1489
+ for label, metrics in report.items():
1490
+ if isinstance(metrics, dict):
1491
+ row = {"label": label}
1492
+ for key, value in metrics.items():
1493
+ if key == "support":
1494
+ row[key] = int(value)
1495
+ else:
1496
+ row[key] = round(float(value), 4)
1497
+ rows.append(row)
1498
+ else:
1499
+ rows.append({"label": label, "accuracy": round(float(metrics), 4)})
1500
+ return pd.DataFrame(rows)
1501
+
1502
+
1503
+ def confusion_matrix_to_dataframe(
1504
+ confusion: Sequence[Sequence[float]], labels: Sequence[str]
1505
+ ) -> pd.DataFrame:
1506
+ if not confusion:
1507
+ return pd.DataFrame()
1508
+ df = pd.DataFrame(confusion, index=list(labels), columns=list(labels))
1509
+ df.index.name = "True Label"
1510
+ df.columns.name = "Predicted Label"
1511
+ return df
1512
+
1513
+
1514
+ # --------------------------------------------------------------------------------------
1515
+ # Gradio interface
1516
+ # --------------------------------------------------------------------------------------
1517
+
1518
+
1519
+ def build_interface() -> gr.Blocks:
1520
+ theme = gr.themes.Soft(
1521
+ primary_hue="sky", secondary_hue="blue", neutral_hue="gray"
1522
+ ).set(
1523
+ body_background_fill="#1f1f1f",
1524
+ body_text_color="#f5f5f5",
1525
+ block_background_fill="#262626",
1526
+ block_border_color="#333333",
1527
+ button_primary_background_fill="#5ac8fa",
1528
+ button_primary_background_fill_hover="#48b5eb",
1529
+ button_primary_border_color="#38bdf8",
1530
+ button_primary_text_color="#0f172a",
1531
+ button_secondary_background_fill="#3f3f46",
1532
+ button_secondary_text_color="#f5f5f5",
1533
+ )
1534
+
1535
+ def _normalise_directory_string(value: Optional[Union[str, Path]]) -> str:
1536
+ if value is None:
1537
+ return ""
1538
+ path = Path(value).expanduser()
1539
+ try:
1540
+ return str(path.resolve())
1541
+ except Exception:
1542
+ return str(path)
1543
+
1544
+ with gr.Blocks(
1545
+ title="Fault Classification - PMU Data", theme=theme, css=APP_CSS
1546
+ ) as demo:
1547
+ gr.Markdown("# Fault Classification for PMU & PV Data")
1548
+ gr.Markdown(
1549
+ "🖥️ TensorFlow is locked to CPU execution so the Space can run without CUDA drivers."
1550
+ )
1551
+ if MODEL is None or SCALER is None:
1552
+ gr.Markdown(
1553
+ "⚠️ **Artifacts Missing** — Upload `pmu_cnn_lstm_model.keras`, "
1554
+ "`pmu_feature_scaler.pkl`, and `pmu_metadata.json` to enable inference, "
1555
+ "or configure the Hugging Face Hub environment variables so they can be downloaded."
1556
+ )
1557
+ else:
1558
+ class_count = len(LABEL_CLASSES) if LABEL_CLASSES else "unknown"
1559
+ gr.Markdown(
1560
+ f"Loaded a **{MODEL_TYPE.upper()}** model ({MODEL_FORMAT.upper()}) with "
1561
+ f"{len(FEATURE_COLUMNS)} features, sequence length **{SEQUENCE_LENGTH}**, and "
1562
+ f"{class_count} target classes. Use the tabs below to run inference or fine-tune "
1563
+ "the model with your own CSV files."
1564
+ )
1565
+
1566
+ with gr.Accordion("Feature Reference", open=False):
1567
+ gr.Markdown(
1568
+ f"Each time window expects **{len(FEATURE_COLUMNS)} features** ordered as follows:\n"
1569
+ + "\n".join(f"- {name}" for name in FEATURE_COLUMNS)
1570
+ )
1571
+ gr.Markdown(
1572
+ f"Default training parameters: **sequence length = {SEQUENCE_LENGTH}**, "
1573
+ f"**stride = {DEFAULT_WINDOW_STRIDE}**. Adjust them in the tabs as needed."
1574
+ )
1575
+
1576
+ with gr.Tabs():
1577
+ with gr.Tab("Overview"):
1578
+ gr.Markdown(PROJECT_OVERVIEW_MD)
1579
+ with gr.Tab("Inference"):
1580
+ gr.Markdown("## Run Inference")
1581
+ with gr.Row():
1582
+ file_in = gr.File(label="Upload PMU CSV", file_types=[".csv"])
1583
+ text_in = gr.Textbox(
1584
+ lines=4,
1585
+ label="Or paste a single window (comma separated)",
1586
+ placeholder="49.97772,1.215825E-38,...",
1587
+ )
1588
+
1589
+ with gr.Row():
1590
+ sequence_length_input = gr.Slider(
1591
+ minimum=1,
1592
+ maximum=max(1, SEQUENCE_LENGTH * 2),
1593
+ step=1,
1594
+ value=SEQUENCE_LENGTH,
1595
+ label="Sequence length (timesteps)",
1596
+ )
1597
+ stride_input = gr.Slider(
1598
+ minimum=1,
1599
+ maximum=max(1, SEQUENCE_LENGTH),
1600
+ step=1,
1601
+ value=max(1, DEFAULT_WINDOW_STRIDE),
1602
+ label="CSV window stride",
1603
+ )
1604
+
1605
+ predict_btn = gr.Button("🚀 Run Inference", variant="primary")
1606
+ status_out = gr.Textbox(label="Status", interactive=False)
1607
+ table_out = gr.Dataframe(
1608
+ headers=["window", "predicted_label", "confidence", "top3"],
1609
+ label="Predictions",
1610
+ interactive=False,
1611
+ )
1612
+ probs_out = gr.JSON(label="Per-window probabilities")
1613
+
1614
+ def _run_prediction(file_obj, text, sequence_length, stride):
1615
+ sequence_length = int(sequence_length)
1616
+ stride = int(stride)
1617
+ try:
1618
+ if file_obj is not None:
1619
+ return predict_from_csv(file_obj, sequence_length, stride)
1620
+ if text and text.strip():
1621
+ return predict_from_text(text, sequence_length)
1622
+ return (
1623
+ "Please upload a CSV file or provide feature values.",
1624
+ pd.DataFrame(),
1625
+ [],
1626
+ )
1627
+ except Exception as exc:
1628
+ return f"Prediction failed: {exc}", pd.DataFrame(), []
1629
+
1630
+ predict_btn.click(
1631
+ _run_prediction,
1632
+ inputs=[file_in, text_in, sequence_length_input, stride_input],
1633
+ outputs=[status_out, table_out, probs_out],
1634
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1635
+ )
1636
+
1637
+ with gr.Tab("Training"):
1638
+ gr.Markdown("## Train or Fine-tune the Model")
1639
+ gr.Markdown(
1640
+ "Training data is automatically downloaded from the database. "
1641
+ "Refresh the cache if new files are added upstream."
1642
+ )
1643
+
1644
+ training_files_state = gr.State([])
1645
+ with gr.Row():
1646
+ with gr.Column(scale=3):
1647
+ training_files_summary = gr.Textbox(
1648
+ label="Database training CSVs",
1649
+ value="Training dataset not loaded yet.",
1650
+ lines=4,
1651
+ interactive=False,
1652
+ elem_id="training-files-summary",
1653
+ )
1654
+ with gr.Column(scale=2, min_width=240):
1655
+ dataset_info = gr.Markdown(
1656
+ "No local database CSVs downloaded yet.",
1657
+ )
1658
+ dataset_refresh = gr.Button(
1659
+ "🔄 Reload dataset from database",
1660
+ variant="secondary",
1661
+ )
1662
+ clear_cache_button = gr.Button(
1663
+ "🧹 Clear downloaded cache",
1664
+ variant="secondary",
1665
+ )
1666
+
1667
+ with gr.Accordion("📂 DataBaseBrowser", open=False):
1668
+ gr.Markdown(
1669
+ "Browse the upstream database by date and download only the CSVs you need."
1670
+ )
1671
+ with gr.Row(elem_id="date-browser-row"):
1672
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1673
+ year_selector = gr.Dropdown(label="Year", choices=[])
1674
+ year_download_button = gr.Button(
1675
+ "⬇️ Download year CSVs", variant="secondary"
1676
+ )
1677
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1678
+ month_selector = gr.Dropdown(label="Month", choices=[])
1679
+ month_download_button = gr.Button(
1680
+ "⬇️ Download month CSVs", variant="secondary"
1681
+ )
1682
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1683
+ day_selector = gr.Dropdown(label="Day", choices=[])
1684
+ day_download_button = gr.Button(
1685
+ "⬇️ Download day CSVs", variant="secondary"
1686
+ )
1687
+ with gr.Column(elem_id="available-files-section"):
1688
+ available_files = gr.CheckboxGroup(
1689
+ label="Available CSV files",
1690
+ choices=[],
1691
+ value=[],
1692
+ elem_id="available-files-grid",
1693
+ )
1694
+ download_button = gr.Button(
1695
+ "⬇️ Download selected CSVs",
1696
+ variant="secondary",
1697
+ elem_id="download-selected-button",
1698
+ )
1699
+ repo_status = gr.Markdown(
1700
+ "Click 'Reload dataset from database' to fetch the directory tree."
1701
+ )
1702
+
1703
+ with gr.Row():
1704
+ label_input = gr.Dropdown(
1705
+ value=LABEL_COLUMN,
1706
+ choices=[LABEL_COLUMN],
1707
+ allow_custom_value=True,
1708
+ label="Label column name",
1709
+ )
1710
+ model_selector = gr.Radio(
1711
+ choices=["CNN-LSTM", "TCN", "SVM"],
1712
+ value=(
1713
+ "TCN"
1714
+ if MODEL_TYPE == "tcn"
1715
+ else ("SVM" if MODEL_TYPE == "svm" else "CNN-LSTM")
1716
+ ),
1717
+ label="Model architecture",
1718
+ )
1719
+ sequence_length_train = gr.Slider(
1720
+ minimum=4,
1721
+ maximum=max(32, SEQUENCE_LENGTH * 2),
1722
+ step=1,
1723
+ value=SEQUENCE_LENGTH,
1724
+ label="Sequence length",
1725
+ )
1726
+ stride_train = gr.Slider(
1727
+ minimum=1,
1728
+ maximum=max(32, SEQUENCE_LENGTH * 2),
1729
+ step=1,
1730
+ value=max(1, DEFAULT_WINDOW_STRIDE),
1731
+ label="Stride",
1732
+ )
1733
+
1734
+ model_default = MODEL_FILENAME_BY_TYPE.get(
1735
+ MODEL_TYPE, Path(LOCAL_MODEL_FILE).name
1736
+ )
1737
+
1738
+ with gr.Row():
1739
+ validation_train = gr.Slider(
1740
+ minimum=0.05,
1741
+ maximum=0.4,
1742
+ step=0.05,
1743
+ value=0.2,
1744
+ label="Validation split",
1745
+ )
1746
+ batch_train = gr.Slider(
1747
+ minimum=32,
1748
+ maximum=512,
1749
+ step=32,
1750
+ value=128,
1751
+ label="Batch size",
1752
+ )
1753
+ epochs_train = gr.Slider(
1754
+ minimum=5,
1755
+ maximum=100,
1756
+ step=5,
1757
+ value=50,
1758
+ label="Epochs",
1759
+ )
1760
+
1761
+ directory_choices, directory_default = gather_directory_choices(
1762
+ str(MODEL_OUTPUT_DIR)
1763
+ )
1764
+ artifact_choices, default_artifact = gather_artifact_choices(
1765
+ directory_default
1766
+ )
1767
+
1768
+ with gr.Row():
1769
+ output_directory = gr.Dropdown(
1770
+ value=directory_default,
1771
+ label="Output directory",
1772
+ choices=directory_choices,
1773
+ allow_custom_value=True,
1774
+ )
1775
+ model_name = gr.Textbox(
1776
+ value=model_default,
1777
+ label="Model output filename",
1778
+ )
1779
+ scaler_name = gr.Textbox(
1780
+ value=Path(LOCAL_SCALER_FILE).name,
1781
+ label="Scaler output filename",
1782
+ )
1783
+ metadata_name = gr.Textbox(
1784
+ value=Path(LOCAL_METADATA_FILE).name,
1785
+ label="Metadata output filename",
1786
+ )
1787
+
1788
+ with gr.Row():
1789
+ artifact_browser = gr.Dropdown(
1790
+ label="Saved artifacts in directory",
1791
+ choices=artifact_choices,
1792
+ value=default_artifact,
1793
+ )
1794
+ artifact_download_button = gr.DownloadButton(
1795
+ "⬇️ Download selected artifact",
1796
+ value=default_artifact,
1797
+ visible=bool(default_artifact),
1798
+ variant="secondary",
1799
+ )
1800
+
1801
+ def on_output_directory_change(selected_dir, current_selection):
1802
+ choices, normalised = gather_directory_choices(selected_dir)
1803
+ artifact_options, selected = gather_artifact_choices(
1804
+ normalised, current_selection
1805
+ )
1806
+ return (
1807
+ gr.update(choices=choices, value=normalised),
1808
+ gr.update(choices=artifact_options, value=selected),
1809
+ download_button_state(selected),
1810
+ )
1811
+
1812
+ def on_artifact_change(selected_path):
1813
+ return download_button_state(selected_path)
1814
+
1815
+ output_directory.change(
1816
+ on_output_directory_change,
1817
+ inputs=[output_directory, artifact_browser],
1818
+ outputs=[
1819
+ output_directory,
1820
+ artifact_browser,
1821
+ artifact_download_button,
1822
+ ],
1823
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1824
+ )
1825
+
1826
+ artifact_browser.change(
1827
+ on_artifact_change,
1828
+ inputs=[artifact_browser],
1829
+ outputs=[artifact_download_button],
1830
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1831
+ )
1832
+
1833
+ with gr.Row(elem_id="artifact-download-row"):
1834
+ model_download_button = gr.DownloadButton(
1835
+ "⬇️ Download model file",
1836
+ value=None,
1837
+ visible=False,
1838
+ elem_classes=["artifact-download-button"],
1839
+ )
1840
+ scaler_download_button = gr.DownloadButton(
1841
+ "⬇️ Download scaler file",
1842
+ value=None,
1843
+ visible=False,
1844
+ elem_classes=["artifact-download-button"],
1845
+ )
1846
+ metadata_download_button = gr.DownloadButton(
1847
+ "⬇️ Download metadata file",
1848
+ value=None,
1849
+ visible=False,
1850
+ elem_classes=["artifact-download-button"],
1851
+ )
1852
+ tensorboard_download_button = gr.DownloadButton(
1853
+ "⬇️ Download TensorBoard logs",
1854
+ value=None,
1855
+ visible=False,
1856
+ elem_classes=["artifact-download-button"],
1857
+ )
1858
+
1859
+ model_download_button.file_name = Path(LOCAL_MODEL_FILE).name
1860
+ scaler_download_button.file_name = Path(LOCAL_SCALER_FILE).name
1861
+ metadata_download_button.file_name = Path(LOCAL_METADATA_FILE).name
1862
+ tensorboard_download_button.file_name = "tensorboard_logs.zip"
1863
+
1864
+ tensorboard_toggle = gr.Checkbox(
1865
+ value=True,
1866
+ label="Enable TensorBoard logging (creates downloadable archive)",
1867
+ )
1868
+
1869
+ def _suggest_model_filename(choice: str, current_value: str):
1870
+ choice_key = (choice or "cnn_lstm").lower().replace("-", "_")
1871
+ suggested = MODEL_FILENAME_BY_TYPE.get(
1872
+ choice_key, Path(LOCAL_MODEL_FILE).name
1873
+ )
1874
+ known_defaults = set(MODEL_FILENAME_BY_TYPE.values())
1875
+ current_name = Path(current_value).name if current_value else ""
1876
+ if current_name and current_name not in known_defaults:
1877
+ return gr.update()
1878
+ return gr.update(value=suggested)
1879
+
1880
+ model_selector.change(
1881
+ _suggest_model_filename,
1882
+ inputs=[model_selector, model_name],
1883
+ outputs=model_name,
1884
+ )
1885
+
1886
+ with gr.Row():
1887
+ train_button = gr.Button("🛠️ Start Training", variant="primary")
1888
+ progress_button = gr.Button(
1889
+ "📊 Check Progress", variant="secondary"
1890
+ )
1891
+
1892
+ # Training status display
1893
+ training_status = gr.Textbox(label="Training Status", interactive=False)
1894
+ report_output = gr.Dataframe(
1895
+ label="Classification report", interactive=False
1896
+ )
1897
+ history_output = gr.JSON(label="Training history")
1898
+ confusion_output = gr.Dataframe(
1899
+ label="Confusion matrix", interactive=False
1900
+ )
1901
+
1902
+ # Message area at the bottom for progress updates
1903
+ with gr.Accordion("📋 Progress Messages", open=True):
1904
+ progress_messages = gr.Textbox(
1905
+ label="Training Messages",
1906
+ lines=8,
1907
+ max_lines=20,
1908
+ interactive=False,
1909
+ autoscroll=True,
1910
+ placeholder="Click 'Check Progress' to see training updates...",
1911
+ )
1912
+ with gr.Row():
1913
+ gr.Button("🗑️ Clear Messages", variant="secondary").click(
1914
+ lambda: "", outputs=[progress_messages]
1915
+ )
1916
+
1917
+ def _run_training(
1918
+ file_paths,
1919
+ label_column,
1920
+ model_choice,
1921
+ sequence_length,
1922
+ stride,
1923
+ validation_split,
1924
+ batch_size,
1925
+ epochs,
1926
+ output_dir,
1927
+ model_filename,
1928
+ scaler_filename,
1929
+ metadata_filename,
1930
+ enable_tensorboard,
1931
+ ):
1932
+ base_dir = normalise_output_directory(output_dir)
1933
+ try:
1934
+ base_dir.mkdir(parents=True, exist_ok=True)
1935
+
1936
+ model_path = resolve_output_path(
1937
+ base_dir,
1938
+ model_filename,
1939
+ Path(LOCAL_MODEL_FILE).name,
1940
+ )
1941
+ scaler_path = resolve_output_path(
1942
+ base_dir,
1943
+ scaler_filename,
1944
+ Path(LOCAL_SCALER_FILE).name,
1945
+ )
1946
+ metadata_path = resolve_output_path(
1947
+ base_dir,
1948
+ metadata_filename,
1949
+ Path(LOCAL_METADATA_FILE).name,
1950
+ )
1951
+
1952
+ model_path.parent.mkdir(parents=True, exist_ok=True)
1953
+ scaler_path.parent.mkdir(parents=True, exist_ok=True)
1954
+ metadata_path.parent.mkdir(parents=True, exist_ok=True)
1955
+
1956
+ # Create status file path for progress tracking
1957
+ status_file = model_path.parent / "training_status.txt"
1958
+
1959
+ # Initialize status
1960
+ with open(status_file, "w") as f:
1961
+ f.write("Starting training setup...")
1962
+
1963
+ if not file_paths:
1964
+ raise ValueError(
1965
+ "No training CSVs were found in the database cache. "
1966
+ "Use 'Reload dataset from database' and try again."
1967
+ )
1968
+
1969
+ with open(status_file, "w") as f:
1970
+ f.write("Loading and validating CSV files...")
1971
+
1972
+ available_paths = [
1973
+ path for path in file_paths if Path(path).exists()
1974
+ ]
1975
+ missing_paths = [
1976
+ Path(path).name
1977
+ for path in file_paths
1978
+ if not Path(path).exists()
1979
+ ]
1980
+ if not available_paths:
1981
+ raise ValueError(
1982
+ "Database training dataset is unavailable. Reload the dataset and retry."
1983
+ )
1984
+
1985
+ dfs = [load_measurement_csv(path) for path in available_paths]
1986
+ combined = pd.concat(dfs, ignore_index=True)
1987
+
1988
+ # Validate data size and provide recommendations
1989
+ total_samples = len(combined)
1990
+ if total_samples < 100:
1991
+ print(
1992
+ f"Warning: Only {total_samples} samples. Recommend at least 1000 for good results."
1993
+ )
1994
+ print(
1995
+ "Automatically switching to SVM for small dataset compatibility."
1996
+ )
1997
+ if model_choice in ["cnn_lstm", "tcn"]:
1998
+ model_choice = "svm"
1999
+ print(
2000
+ f"Model type changed to SVM for better small dataset performance."
2001
+ )
2002
+ if total_samples < 10:
2003
+ raise ValueError(
2004
+ f"Insufficient data: {total_samples} samples. Need at least 10 samples for training."
2005
+ )
2006
+
2007
+ label_column = (label_column or LABEL_COLUMN).strip()
2008
+ if not label_column:
2009
+ raise ValueError("Label column name cannot be empty.")
2010
+
2011
+ model_choice = (
2012
+ (model_choice or "CNN-LSTM").lower().replace("-", "_")
2013
+ )
2014
+ if model_choice not in {"cnn_lstm", "tcn", "svm"}:
2015
+ raise ValueError(
2016
+ "Select CNN-LSTM, TCN, or SVM for the model architecture."
2017
+ )
2018
+
2019
+ with open(status_file, "w") as f:
2020
+ f.write(
2021
+ f"Starting {model_choice.upper()} training with {len(combined)} samples..."
2022
+ )
2023
+
2024
+ # Start training
2025
+ result = train_from_dataframe(
2026
+ combined,
2027
+ label_column=label_column,
2028
+ feature_columns=None,
2029
+ sequence_length=int(sequence_length),
2030
+ stride=int(stride),
2031
+ validation_split=float(validation_split),
2032
+ batch_size=int(batch_size),
2033
+ epochs=int(epochs),
2034
+ model_type=model_choice,
2035
+ model_path=model_path,
2036
+ scaler_path=scaler_path,
2037
+ metadata_path=metadata_path,
2038
+ enable_tensorboard=bool(enable_tensorboard),
2039
+ )
2040
+
2041
+ refresh_artifacts(
2042
+ Path(result["model_path"]),
2043
+ Path(result["scaler_path"]),
2044
+ Path(result["metadata_path"]),
2045
+ )
2046
+
2047
+ report_df = classification_report_to_dataframe(
2048
+ result["classification_report"]
2049
+ )
2050
+ confusion_df = confusion_matrix_to_dataframe(
2051
+ result["confusion_matrix"], result["class_names"]
2052
+ )
2053
+ tensorboard_dir = result.get("tensorboard_log_dir")
2054
+ tensorboard_zip = result.get("tensorboard_zip_path")
2055
+
2056
+ architecture = result["model_type"].replace("_", "-").upper()
2057
+ status = (
2058
+ f"Training complete using a {architecture} architecture. "
2059
+ f"{result['num_sequences']} windows derived from "
2060
+ f"{result['num_samples']} rows across {len(available_paths)} file(s)."
2061
+ f" Artifacts saved to:"
2062
+ f"\n• Model: {result['model_path']}\n"
2063
+ f"• Scaler: {result['scaler_path']}\n"
2064
+ f"• Metadata: {result['metadata_path']}"
2065
+ )
2066
+
2067
+ status += f"\nLabel column used: {result.get('label_column', label_column)}"
2068
+
2069
+ if tensorboard_dir:
2070
+ status += (
2071
+ f"\nTensorBoard logs directory: {tensorboard_dir}"
2072
+ f'\nRun `tensorboard --logdir "{tensorboard_dir}"` to inspect the training curves.'
2073
+ "\nDownload the archive below to explore the run offline."
2074
+ )
2075
+
2076
+ if missing_paths:
2077
+ skipped = ", ".join(missing_paths)
2078
+ status = f"⚠️ Skipped missing files: {skipped}\n" + status
2079
+
2080
+ artifact_choices, selected_artifact = gather_artifact_choices(
2081
+ str(base_dir), result["model_path"]
2082
+ )
2083
+
2084
+ return (
2085
+ status,
2086
+ report_df,
2087
+ result["history"],
2088
+ confusion_df,
2089
+ download_button_state(result["model_path"]),
2090
+ download_button_state(result["scaler_path"]),
2091
+ download_button_state(result["metadata_path"]),
2092
+ download_button_state(tensorboard_zip),
2093
+ gr.update(value=result.get("label_column", label_column)),
2094
+ gr.update(
2095
+ choices=artifact_choices, value=selected_artifact
2096
+ ),
2097
+ download_button_state(selected_artifact),
2098
+ )
2099
+ except Exception as exc:
2100
+ artifact_choices, selected_artifact = gather_artifact_choices(
2101
+ str(base_dir)
2102
+ )
2103
+ return (
2104
+ f"Training failed: {exc}",
2105
+ pd.DataFrame(),
2106
+ {},
2107
+ pd.DataFrame(),
2108
+ download_button_state(None),
2109
+ download_button_state(None),
2110
+ download_button_state(None),
2111
+ download_button_state(None),
2112
+ gr.update(),
2113
+ gr.update(
2114
+ choices=artifact_choices, value=selected_artifact
2115
+ ),
2116
+ download_button_state(selected_artifact),
2117
+ )
2118
+
2119
+ def _check_progress(output_dir, model_filename, current_messages):
2120
+ """Check training progress by reading status file and accumulate messages."""
2121
+ model_path = resolve_output_path(
2122
+ output_dir, model_filename, Path(LOCAL_MODEL_FILE).name
2123
+ )
2124
+ status_file = model_path.parent / "training_status.txt"
2125
+ status_message = read_training_status(str(status_file))
2126
+
2127
+ # Add timestamp to the message
2128
+ from datetime import datetime
2129
+
2130
+ timestamp = datetime.now().strftime("%H:%M:%S")
2131
+ new_message = f"[{timestamp}] {status_message}"
2132
+
2133
+ # Accumulate messages, keeping last 50 lines to prevent overflow
2134
+ if current_messages:
2135
+ lines = current_messages.split("\n")
2136
+ lines.append(new_message)
2137
+ # Keep only last 50 lines
2138
+ if len(lines) > 50:
2139
+ lines = lines[-50:]
2140
+ accumulated_messages = "\n".join(lines)
2141
+ else:
2142
+ accumulated_messages = new_message
2143
+
2144
+ return accumulated_messages
2145
+
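+ # A sketch of the file-based progress contract assumed above: the training
+ # run overwrites training_status.txt beside the model file, and the
+ # "Check Progress" button re-reads it on demand, e.g.
+ #
+ #     status_file = model_path.parent / "training_status.txt"
+ #     status_file.write_text("Epoch 3/50 complete")    # writer (trainer)
+ #     latest = read_training_status(str(status_file))  # reader (UI button)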
2146
+ train_button.click(
2147
+ _run_training,
2148
+ inputs=[
2149
+ training_files_state,
2150
+ label_input,
2151
+ model_selector,
2152
+ sequence_length_train,
2153
+ stride_train,
2154
+ validation_train,
2155
+ batch_train,
2156
+ epochs_train,
2157
+ output_directory,
2158
+ model_name,
2159
+ scaler_name,
2160
+ metadata_name,
2161
+ tensorboard_toggle,
2162
+ ],
2163
+ outputs=[
2164
+ training_status,
2165
+ report_output,
2166
+ history_output,
2167
+ confusion_output,
2168
+ model_download_button,
2169
+ scaler_download_button,
2170
+ metadata_download_button,
2171
+ tensorboard_download_button,
2172
+ label_input,
2173
+ artifact_browser,
2174
+ artifact_download_button,
2175
+ ],
2176
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2177
+ )
2178
+
2179
+ progress_button.click(
2180
+ _check_progress,
2181
+ inputs=[output_directory, model_name, progress_messages],
2182
+ outputs=[progress_messages],
2183
+ )
2184
+
2185
+ year_selector.change(
2186
+ on_year_change,
2187
+ inputs=[year_selector],
2188
+ outputs=[
2189
+ month_selector,
2190
+ day_selector,
2191
+ available_files,
2192
+ repo_status,
2193
+ ],
2194
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2195
+ )
2196
+
2197
+ month_selector.change(
2198
+ on_month_change,
2199
+ inputs=[year_selector, month_selector],
2200
+ outputs=[day_selector, available_files, repo_status],
2201
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2202
+ )
2203
+
2204
+ day_selector.change(
2205
+ on_day_change,
2206
+ inputs=[year_selector, month_selector, day_selector],
2207
+ outputs=[available_files, repo_status],
2208
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2209
+ )
2210
+
2211
+ download_button.click(
2212
+ download_selected_files,
2213
+ inputs=[
2214
+ year_selector,
2215
+ month_selector,
2216
+ day_selector,
2217
+ available_files,
2218
+ label_input,
2219
+ ],
2220
+ outputs=[
2221
+ training_files_state,
2222
+ training_files_summary,
2223
+ label_input,
2224
+ dataset_info,
2225
+ available_files,
2226
+ repo_status,
2227
+ ],
2228
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2229
+ )
2230
+
2231
+ year_download_button.click(
2232
+ download_year_bundle,
2233
+ inputs=[year_selector, label_input],
2234
+ outputs=[
2235
+ training_files_state,
2236
+ training_files_summary,
2237
+ label_input,
2238
+ dataset_info,
2239
+ available_files,
2240
+ repo_status,
2241
+ ],
2242
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2243
+ )
2244
+
2245
+ month_download_button.click(
2246
+ download_month_bundle,
2247
+ inputs=[year_selector, month_selector, label_input],
2248
+ outputs=[
2249
+ training_files_state,
2250
+ training_files_summary,
2251
+ label_input,
2252
+ dataset_info,
2253
+ available_files,
2254
+ repo_status,
2255
+ ],
2256
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2257
+ )
2258
+
2259
+ day_download_button.click(
2260
+ download_day_bundle,
2261
+ inputs=[year_selector, month_selector, day_selector, label_input],
2262
+ outputs=[
2263
+ training_files_state,
2264
+ training_files_summary,
2265
+ label_input,
2266
+ dataset_info,
2267
+ available_files,
2268
+ repo_status,
2269
+ ],
2270
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2271
+ )
2272
+
2273
+ def _reload_dataset(current_label):
2274
+ local = load_repository_training_files(
2275
+ current_label, force_refresh=True
2276
+ )
2277
+ remote = refresh_remote_browser(force_refresh=True)
2278
+ return (*local, *remote)
2279
+
2280
+ dataset_refresh.click(
2281
+ _reload_dataset,
2282
+ inputs=[label_input],
2283
+ outputs=[
2284
+ training_files_state,
2285
+ training_files_summary,
2286
+ label_input,
2287
+ dataset_info,
2288
+ year_selector,
2289
+ month_selector,
2290
+ day_selector,
2291
+ available_files,
2292
+ repo_status,
2293
+ ],
2294
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2295
+ )
2296
+
2297
+ clear_cache_button.click(
2298
+ clear_downloaded_cache,
2299
+ inputs=[label_input],
2300
+ outputs=[
2301
+ training_files_state,
2302
+ training_files_summary,
2303
+ label_input,
2304
+ dataset_info,
2305
+ year_selector,
2306
+ month_selector,
2307
+ day_selector,
2308
+ available_files,
2309
+ repo_status,
2310
+ ],
2311
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2312
+ )
2313
+
2314
+ def _initialise_dataset():
2315
+ local = load_repository_training_files(
2316
+ LABEL_COLUMN, force_refresh=False
2317
+ )
2318
+ remote = refresh_remote_browser(force_refresh=False)
2319
+ return (*local, *remote)
2320
+
2321
+ demo.load(
2322
+ _initialise_dataset,
2323
+ inputs=None,
2324
+ outputs=[
2325
+ training_files_state,
2326
+ training_files_summary,
2327
+ label_input,
2328
+ dataset_info,
2329
+ year_selector,
2330
+ month_selector,
2331
+ day_selector,
2332
+ available_files,
2333
+ repo_status,
2334
+ ],
2335
+ queue=False,
2336
+ )
2337
+
2338
+ return demo
2339
+
2340
+
2341
+ # --------------------------------------------------------------------------------------
2342
+ # Launch helpers
2343
+ # --------------------------------------------------------------------------------------
2344
+
2345
+
2346
+ def resolve_server_port() -> int:
2347
+ for env_var in ("PORT", "GRADIO_SERVER_PORT"):
2348
+ value = os.environ.get(env_var)
2349
+ if value:
2350
+ try:
2351
+ return int(value)
2352
+ except ValueError:
2353
+ print(f"Ignoring invalid port value from {env_var}: {value}")
2354
+ return 7860
2355
+
2356
+
2357
+ def main():
2358
+ print("Building Gradio interface...")
2359
+ try:
2360
+ demo = build_interface()
2361
+ print("Interface built successfully")
2362
+ except Exception as e:
2363
+ print(f"Failed to build interface: {e}")
2364
+ import traceback
2365
+
2366
+ traceback.print_exc()
2367
+ return
2368
+
2369
+ print("Setting up queue...")
2370
+ try:
2371
+ demo.queue(max_size=QUEUE_MAX_SIZE)
2372
+ print("Queue configured")
2373
+ except Exception as e:
2374
+ print(f"Failed to configure queue: {e}")
2375
+
2376
+ try:
2377
+ port = resolve_server_port()
2378
+ print(f"Launching Gradio app on port {port}")
2379
+ demo.launch(server_name="0.0.0.0", server_port=port, show_error=True)
2380
+ except OSError as exc:
2381
+ print("Failed to launch on requested port:", exc)
2382
+ try:
2383
+ demo.launch(server_name="0.0.0.0", show_error=True)
2384
+ except Exception as e:
2385
+ print(f"Failed to launch completely: {e}")
2386
+ except Exception as e:
2387
+ print(f"Unexpected launch error: {e}")
2388
+ import traceback
2389
+
2390
+ traceback.print_exc()
2391
+
2392
+
2393
+ if __name__ == "__main__":
2394
+ print("=" * 50)
2395
+ print("PMU Fault Classification App Starting")
2396
+ print(f"Python version: {os.sys.version}")
2397
+ print(f"Working directory: {os.getcwd()}")
2398
+ print(f"HUB_REPO: {HUB_REPO}")
2399
+ print(f"Model available: {MODEL is not None}")
2400
+ print(f"Scaler available: {SCALER is not None}")
2401
+ print("=" * 50)
2402
+ main()
.history/app_20251009232247.py ADDED
@@ -0,0 +1,2431 @@
1
+ """Gradio front-end for Fault_Classification_PMU_Data models.
2
+
3
+ The application loads a CNN-LSTM model (and accompanying scaler/metadata)
4
+ produced by ``fault_classification_pmu.py`` and exposes a streamlined
5
+ prediction interface optimised for Hugging Face Spaces deployment. It supports
6
+ raw PMU time-series CSV uploads as well as manual comma-separated feature
7
+ vectors.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import os
14
+ import shutil
15
+
16
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1")
17
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
18
+ os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
19
+
20
+ import re
21
+ from pathlib import Path
22
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
23
+
24
+ import gradio as gr
25
+ import joblib
26
+ import numpy as np
27
+ import pandas as pd
28
+ import requests
29
+ from huggingface_hub import hf_hub_download
30
+ from tensorflow.keras.models import load_model
31
+
32
+ from fault_classification_pmu import (
33
+ DEFAULT_FEATURE_COLUMNS as TRAINING_DEFAULT_FEATURE_COLUMNS,
34
+ LABEL_GUESS_CANDIDATES as TRAINING_LABEL_GUESSES,
35
+ train_from_dataframe,
36
+ )
37
+
38
+ # --------------------------------------------------------------------------------------
39
+ # Configuration
40
+ # --------------------------------------------------------------------------------------
41
+ DEFAULT_FEATURE_COLUMNS: List[str] = list(TRAINING_DEFAULT_FEATURE_COLUMNS)
42
+ DEFAULT_SEQUENCE_LENGTH = 32
43
+ DEFAULT_STRIDE = 4
44
+
45
+ LOCAL_MODEL_FILE = os.environ.get("PMU_MODEL_FILE", "pmu_cnn_lstm_model.keras")
46
+ LOCAL_SCALER_FILE = os.environ.get("PMU_SCALER_FILE", "pmu_feature_scaler.pkl")
47
+ LOCAL_METADATA_FILE = os.environ.get("PMU_METADATA_FILE", "pmu_metadata.json")
48
+
49
+ MODEL_OUTPUT_DIR = Path(os.environ.get("PMU_MODEL_DIR", "model")).resolve()
50
+ MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
51
+
52
+ HUB_REPO = os.environ.get("PMU_HUB_REPO", "")
53
+ HUB_MODEL_FILENAME = os.environ.get("PMU_HUB_MODEL_FILENAME", LOCAL_MODEL_FILE)
54
+ HUB_SCALER_FILENAME = os.environ.get("PMU_HUB_SCALER_FILENAME", LOCAL_SCALER_FILE)
55
+ HUB_METADATA_FILENAME = os.environ.get("PMU_HUB_METADATA_FILENAME", LOCAL_METADATA_FILE)
56
+
57
+ ENV_MODEL_PATH = "PMU_MODEL_PATH"
58
+ ENV_SCALER_PATH = "PMU_SCALER_PATH"
59
+ ENV_METADATA_PATH = "PMU_METADATA_PATH"
60
+
61
+ # --------------------------------------------------------------------------------------
62
+ # Utility functions for loading artifacts
63
+ # --------------------------------------------------------------------------------------
64
+
65
+
66
+ def download_from_hub(filename: str) -> Optional[Path]:
67
+ if not HUB_REPO or not filename:
68
+ return None
69
+ try:
70
+ print(f"Downloading {filename} from {HUB_REPO} ...")
71
+ # hf_hub_download blocks until the file is available or raises on failure.
72
+ path = hf_hub_download(repo_id=HUB_REPO, filename=filename)
73
+ print("Downloaded", path)
74
+ return Path(path)
75
+ except Exception as exc: # pragma: no cover - logging convenience
76
+ print("Failed to download", filename, "from", HUB_REPO, ":", exc)
77
+ print("Continuing without pre-trained model...")
78
+ return None
79
+
80
+
81
+ def resolve_artifact(
82
+ local_name: str, env_var: str, hub_filename: str
83
+ ) -> Optional[Path]:
84
+ print(f"Resolving artifact: {local_name}, env: {env_var}, hub: {hub_filename}")
85
+ candidates = [Path(local_name)] if local_name else []
86
+ if local_name:
87
+ candidates.append(MODEL_OUTPUT_DIR / Path(local_name).name)
88
+ env_value = os.environ.get(env_var)
89
+ if env_value:
90
+ candidates.append(Path(env_value))
91
+
92
+ for candidate in candidates:
93
+ if candidate and candidate.exists():
94
+ print(f"Found local artifact: {candidate}")
95
+ return candidate
96
+
97
+ print(f"No local artifacts found, checking hub...")
98
+ # Only try to download if we have a hub repo configured
99
+ if HUB_REPO:
100
+ return download_from_hub(hub_filename)
101
+ else:
102
+ print("No HUB_REPO configured, skipping download")
103
+ return None
104
+
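+ # Resolution order implemented above, as a worked example (the file names
+ # are the defaults; none of these paths is guaranteed to exist):
+ #
+ #     resolve_artifact("pmu_metadata.json", "PMU_METADATA_PATH", "pmu_metadata.json")
+ #     # 1. ./pmu_metadata.json                   (current directory)
+ #     # 2. MODEL_OUTPUT_DIR / "pmu_metadata.json"
+ #     # 3. $PMU_METADATA_PATH                    (environment override)
+ #     # 4. hf_hub_download(...) when HUB_REPO is set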
105
+
106
+ def load_metadata(path: Optional[Path]) -> Dict:
107
+ if path and path.exists():
108
+ try:
109
+ return json.loads(path.read_text())
110
+ except Exception as exc: # pragma: no cover - metadata parsing errors
111
+ print("Failed to read metadata", path, exc)
112
+ return {}
113
+
114
+
115
+ def try_load_scaler(path: Optional[Path]):
116
+ if not path:
117
+ return None
118
+ try:
119
+ scaler = joblib.load(path)
120
+ print("Loaded scaler from", path)
121
+ return scaler
122
+ except Exception as exc:
123
+ print("Failed to load scaler", path, exc)
124
+ return None
125
+
126
+
127
+ # Initialize paths with error handling
128
+ print("Starting application initialization...")
129
+ try:
130
+ MODEL_PATH = resolve_artifact(LOCAL_MODEL_FILE, ENV_MODEL_PATH, HUB_MODEL_FILENAME)
131
+ print(f"Model path resolved: {MODEL_PATH}")
132
+ except Exception as e:
133
+ print(f"Model path resolution failed: {e}")
134
+ MODEL_PATH = None
135
+
136
+ try:
137
+ SCALER_PATH = resolve_artifact(
138
+ LOCAL_SCALER_FILE, ENV_SCALER_PATH, HUB_SCALER_FILENAME
139
+ )
140
+ print(f"Scaler path resolved: {SCALER_PATH}")
141
+ except Exception as e:
142
+ print(f"Scaler path resolution failed: {e}")
143
+ SCALER_PATH = None
144
+
145
+ try:
146
+ METADATA_PATH = resolve_artifact(
147
+ LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME
148
+ )
149
+ print(f"Metadata path resolved: {METADATA_PATH}")
150
+ except Exception as e:
151
+ print(f"Metadata path resolution failed: {e}")
152
+ METADATA_PATH = None
153
+
154
+ try:
155
+ METADATA = load_metadata(METADATA_PATH)
156
+ print(f"Metadata loaded: {len(METADATA)} entries")
157
+ except Exception as e:
158
+ print(f"Metadata loading failed: {e}")
159
+ METADATA = {}
160
+
161
+ # Queuing configuration
162
+ QUEUE_MAX_SIZE = 32
163
+ # Apply a small per-event concurrency limit to avoid relying on the deprecated
164
+ # ``concurrency_count`` parameter when enabling Gradio's request queue.
165
+ EVENT_CONCURRENCY_LIMIT = 2
166
+
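+ # Sketch of how the two knobs above cooperate (mirrors the wiring done in
+ # build_interface and main(); the button and handler names are hypothetical):
+ #
+ #     demo.queue(max_size=QUEUE_MAX_SIZE)   # cap the number of queued requests
+ #     button.click(handler, inputs=[...], outputs=[...],
+ #                  concurrency_limit=EVENT_CONCURRENCY_LIMIT)  # per-event cap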
167
+
168
+ def try_load_model(path: Optional[Path], model_type: str, model_format: str):
169
+ if not path:
170
+ return None
171
+ try:
172
+ if model_type == "svm" or model_format == "joblib":
173
+ model = joblib.load(path)
174
+ else:
175
+ model = load_model(path)
176
+ print("Loaded model from", path)
177
+ return model
178
+ except Exception as exc: # pragma: no cover - runtime diagnostics
179
+ print("Failed to load model", path, exc)
180
+ return None
181
+
182
+
183
+ FEATURE_COLUMNS: List[str] = list(DEFAULT_FEATURE_COLUMNS)
184
+ LABEL_CLASSES: List[str] = []
185
+ LABEL_COLUMN: str = "Fault"
186
+ SEQUENCE_LENGTH: int = DEFAULT_SEQUENCE_LENGTH
187
+ DEFAULT_WINDOW_STRIDE: int = DEFAULT_STRIDE
188
+ MODEL_TYPE: str = "cnn_lstm"
189
+ MODEL_FORMAT: str = "keras"
190
+
191
+
192
+ def _model_output_path(filename: str) -> str:
193
+ return str(MODEL_OUTPUT_DIR / Path(filename).name)
194
+
195
+
196
+ MODEL_FILENAME_BY_TYPE: Dict[str, str] = {
197
+ "cnn_lstm": Path(LOCAL_MODEL_FILE).name,
198
+ "tcn": "pmu_tcn_model.keras",
199
+ "svm": "pmu_svm_model.joblib",
200
+ }
201
+
202
+ REQUIRED_PMU_COLUMNS: Tuple[str, ...] = tuple(DEFAULT_FEATURE_COLUMNS)
203
+ TRAINING_UPLOAD_DIR = Path(
204
+ os.environ.get("PMU_TRAINING_UPLOAD_DIR", "training_uploads")
205
+ )
206
+ TRAINING_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
207
+
208
+ TRAINING_DATA_REPO = os.environ.get(
209
+ "PMU_TRAINING_DATA_REPO", "VincentCroft/ThesisModelData"
210
+ )
211
+ TRAINING_DATA_BRANCH = os.environ.get("PMU_TRAINING_DATA_BRANCH", "main")
212
+ TRAINING_DATA_DIR = Path(os.environ.get("PMU_TRAINING_DATA_DIR", "training_dataset"))
213
+ TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
214
+
215
+ GITHUB_CONTENT_CACHE: Dict[str, List[Dict[str, Any]]] = {}
216
+
217
+
218
+ APP_CSS = """
219
+ #available-files-section {
220
+ position: relative;
221
+ display: flex;
222
+ flex-direction: column;
223
+ gap: 0.75rem;
224
+ border-radius: 0.75rem;
225
+ isolation: isolate;
226
+ }
227
+
228
+ #available-files-grid {
229
+ position: static;
230
+ overflow: visible;
231
+ }
232
+
233
+ #available-files-grid .form {
234
+ position: static;
235
+ min-height: 16rem;
236
+ }
237
+
238
+ #available-files-grid .wrap {
239
+ display: grid;
240
+ grid-template-columns: repeat(4, minmax(0, 1fr));
241
+ gap: 0.5rem;
242
+ max-height: 24rem;
243
+ min-height: 16rem;
244
+ overflow-y: auto;
245
+ padding-right: 0.25rem;
246
+ }
247
+
248
+ #available-files-grid .wrap > div {
249
+ min-width: 0;
250
+ }
251
+
252
+ #available-files-grid .wrap label {
253
+ margin: 0;
254
+ display: flex;
255
+ align-items: center;
256
+ padding: 0.45rem 0.65rem;
257
+ border-radius: 0.65rem;
258
+ background-color: rgba(255, 255, 255, 0.05);
259
+ border: 1px solid rgba(255, 255, 255, 0.08);
260
+ transition: background-color 0.2s ease, border-color 0.2s ease;
261
+ min-height: 2.5rem;
262
+ }
263
+
264
+ #available-files-grid .wrap label:hover {
265
+ background-color: rgba(90, 200, 250, 0.16);
266
+ border-color: rgba(90, 200, 250, 0.4);
267
+ }
268
+
269
+ #available-files-grid .wrap label span {
270
+ overflow: hidden;
271
+ text-overflow: ellipsis;
272
+ white-space: nowrap;
273
+ }
274
+
275
+ #available-files-section .gradio-loading,
276
+ #available-files-grid .gradio-loading {
277
+ position: absolute;
278
+ top: 0;
279
+ left: 0;
280
+ right: 0;
281
+ bottom: 0;
282
+ width: 100%;
283
+ height: 100%;
284
+ display: flex;
285
+ align-items: center;
286
+ justify-content: center;
287
+ background: rgba(10, 14, 23, 0.92);
288
+ border-radius: 0.75rem;
289
+ z-index: 999;
290
+ padding: 1.5rem;
291
+ pointer-events: auto;
292
+ }
312
+
313
+ #available-files-section .gradio-loading > *,
314
+ #available-files-grid .gradio-loading > * {
315
+ width: 100%;
316
+ }
317
+
318
+ #available-files-section .gradio-loading progress,
319
+ #available-files-section .gradio-loading .progress-bar,
320
+ #available-files-section .gradio-loading .loading-progress,
321
+ #available-files-section .gradio-loading [role="progressbar"],
322
+ #available-files-section .gradio-loading .wrap,
323
+ #available-files-section .gradio-loading .inner,
324
+ #available-files-grid .gradio-loading progress,
325
+ #available-files-grid .gradio-loading .progress-bar,
326
+ #available-files-grid .gradio-loading .loading-progress,
327
+ #available-files-grid .gradio-loading [role="progressbar"],
328
+ #available-files-grid .gradio-loading .wrap,
329
+ #available-files-grid .gradio-loading .inner {
330
+ width: 100% !important;
331
+ max-width: none !important;
332
+ }
333
+
334
+ #available-files-section .gradio-loading .status,
335
+ #available-files-section .gradio-loading .message,
336
+ #available-files-section .gradio-loading .label,
337
+ #available-files-grid .gradio-loading .status,
338
+ #available-files-grid .gradio-loading .message,
339
+ #available-files-grid .gradio-loading .label {
340
+ text-align: center;
341
+ }
342
+
343
+ #date-browser-row {
344
+ gap: 0.75rem;
345
+ }
346
+
347
+ #date-browser-row .date-browser-column {
348
+ flex: 1 1 0%;
349
+ min-width: 0;
350
+ }
351
+
352
+ #date-browser-row .date-browser-column > .gradio-dropdown,
353
+ #date-browser-row .date-browser-column > .gradio-button {
354
+ width: 100%;
355
+ }
356
+
357
+ #date-browser-row .date-browser-column > .gradio-dropdown > div {
358
+ width: 100%;
359
+ }
360
+
361
+ #date-browser-row .date-browser-column .gradio-button {
362
+ justify-content: center;
363
+ }
364
+
365
+ #training-files-summary textarea {
366
+ max-height: 12rem;
367
+ overflow-y: auto;
368
+ }
369
+
370
+ #download-selected-button {
371
+ width: 100%;
372
+ position: relative;
373
+ z-index: 0;
374
+ }
375
+
376
+ #download-selected-button .gradio-button {
377
+ width: 100%;
378
+ justify-content: center;
379
+ }
380
+
381
+ #artifact-download-row {
382
+ gap: 0.75rem;
383
+ }
384
+
385
+ #artifact-download-row .artifact-download-button {
386
+ flex: 1 1 0%;
387
+ min-width: 0;
388
+ }
389
+
390
+ #artifact-download-row .artifact-download-button .gradio-button {
391
+ width: 100%;
392
+ justify-content: center;
393
+ }
394
+ """
395
+
396
+
397
+ def _github_cache_key(path: str) -> str:
398
+ return path or "__root__"
399
+
400
+
401
+ def _github_api_url(path: str) -> str:
402
+ clean_path = path.strip("/")
403
+ base = f"https://api.github.com/repos/{TRAINING_DATA_REPO}/contents"
404
+ if clean_path:
405
+ return f"{base}/{clean_path}?ref={TRAINING_DATA_BRANCH}"
406
+ return f"{base}?ref={TRAINING_DATA_BRANCH}"
407
+
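+ # Example of the URLs this produces with the default repo/branch settings
+ # (the "2024/05" path segment is illustrative):
+ #
+ #     _github_api_url("")        -> https://api.github.com/repos/VincentCroft/ThesisModelData/contents?ref=main
+ #     _github_api_url("2024/05") -> https://api.github.com/repos/VincentCroft/ThesisModelData/contents/2024/05?ref=main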
408
+
409
+ def list_remote_directory(
410
+ path: str = "", *, force_refresh: bool = False
411
+ ) -> List[Dict[str, Any]]:
412
+ key = _github_cache_key(path)
413
+ if not force_refresh and key in GITHUB_CONTENT_CACHE:
414
+ return GITHUB_CONTENT_CACHE[key]
415
+
416
+ url = _github_api_url(path)
417
+ response = requests.get(url, timeout=30)
418
+ if response.status_code != 200:
419
+ raise RuntimeError(
420
+ f"GitHub API request failed for `{path or '.'}` (status {response.status_code})."
421
+ )
422
+
423
+ payload = response.json()
424
+ if not isinstance(payload, list):
425
+ raise RuntimeError(
426
+ "Unexpected GitHub API payload. Expected a directory listing."
427
+ )
428
+
429
+ GITHUB_CONTENT_CACHE[key] = payload
430
+ return payload
431
+
432
+
433
+ def list_remote_years(force_refresh: bool = False) -> List[str]:
434
+ entries = list_remote_directory("", force_refresh=force_refresh)
435
+ years = [item["name"] for item in entries if item.get("type") == "dir"]
436
+ return sorted(years)
437
+
438
+
439
+ def list_remote_months(year: str, *, force_refresh: bool = False) -> List[str]:
440
+ if not year:
441
+ return []
442
+ entries = list_remote_directory(year, force_refresh=force_refresh)
443
+ months = [item["name"] for item in entries if item.get("type") == "dir"]
444
+ return sorted(months)
445
+
446
+
447
+ def list_remote_days(
448
+ year: str, month: str, *, force_refresh: bool = False
449
+ ) -> List[str]:
450
+ if not year or not month:
451
+ return []
452
+ entries = list_remote_directory(f"{year}/{month}", force_refresh=force_refresh)
453
+ days = [item["name"] for item in entries if item.get("type") == "dir"]
454
+ return sorted(days)
455
+
456
+
457
+ def list_remote_files(
458
+ year: str, month: str, day: str, *, force_refresh: bool = False
459
+ ) -> List[str]:
460
+ if not year or not month or not day:
461
+ return []
462
+ entries = list_remote_directory(
463
+ f"{year}/{month}/{day}", force_refresh=force_refresh
464
+ )
465
+ files = [item["name"] for item in entries if item.get("type") == "file"]
466
+ return sorted(files)
467
+
468
+
469
+ def download_repository_file(year: str, month: str, day: str, filename: str) -> Path:
470
+ if not filename:
471
+ raise ValueError("Filename cannot be empty when downloading repository data.")
472
+
473
+ relative_parts = [part for part in (year, month, day, filename) if part]
474
+ if len(relative_parts) < 4:
475
+ raise ValueError("Provide year, month, day, and filename to download a CSV.")
476
+
477
+ relative_path = "/".join(relative_parts)
478
+ raw_url = (
479
+ f"https://raw.githubusercontent.com/{TRAINING_DATA_REPO}/"
480
+ f"{TRAINING_DATA_BRANCH}/{relative_path}"
481
+ )
482
+
483
+ response = requests.get(raw_url, stream=True, timeout=120)
484
+ if response.status_code != 200:
485
+ raise RuntimeError(
486
+ f"Failed to download `{relative_path}` (status {response.status_code})."
487
+ )
488
+
489
+ target_dir = TRAINING_DATA_DIR.joinpath(year, month, day)
490
+ target_dir.mkdir(parents=True, exist_ok=True)
491
+ target_path = target_dir / filename
492
+
493
+ with open(target_path, "wb") as handle:
494
+ for chunk in response.iter_content(chunk_size=1 << 20):
495
+ if chunk:
496
+ handle.write(chunk)
497
+
498
+ return target_path
499
+
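+ # Usage sketch (the date parts and filename are hypothetical; real values
+ # come from the dropdowns in the training tab):
+ #
+ #     path = download_repository_file("2024", "05", "17", "pmu_0001.csv")
+ #     # -> training_dataset/2024/05/17/pmu_0001.csv, fetched from
+ #     #    raw.githubusercontent.com and cached locally for reuse.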
500
+
501
+ def _normalise_header(name: str) -> str:
502
+ return str(name).strip().lower()
503
+
504
+
505
+ def guess_label_from_columns(
506
+ columns: Sequence[str], preferred: Optional[str] = None
507
+ ) -> Optional[str]:
508
+ if not columns:
509
+ return preferred
510
+
511
+ lookup = {_normalise_header(col): str(col) for col in columns}
512
+
513
+ if preferred:
514
+ preferred_stripped = preferred.strip()
515
+ for col in columns:
516
+ if str(col).strip() == preferred_stripped:
517
+ return str(col)
518
+ preferred_norm = _normalise_header(preferred)
519
+ if preferred_norm in lookup:
520
+ return lookup[preferred_norm]
521
+
522
+ for guess in TRAINING_LABEL_GUESSES:
523
+ guess_norm = _normalise_header(guess)
524
+ if guess_norm in lookup:
525
+ return lookup[guess_norm]
526
+
527
+ for col in columns:
528
+ if _normalise_header(col).startswith("fault"):
529
+ return str(col)
530
+
531
+ return str(columns[0])
532
+
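+ # The matching precedence above, traced on a hypothetical header row
+ # (assuming no TRAINING_LABEL_GUESSES entry matches first):
+ #
+ #     cols = ["Timestamp", "Frequency", "FAULT_TYPE"]
+ #     guess_label_from_columns(cols, preferred="fault_type")  # -> "FAULT_TYPE"
+ #     guess_label_from_columns(cols)                          # -> "FAULT_TYPE"
+ #     guess_label_from_columns(["A", "B"])                    # -> "A" (fallback)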
533
+
534
+ def summarise_training_files(paths: Sequence[str], notes: Sequence[str]) -> str:
535
+ lines = [Path(path).name for path in paths]
536
+ lines.extend(notes)
537
+ return "\n".join(lines) if lines else "No training files available."
538
+
539
+
540
+ def read_training_status(status_file_path: str) -> str:
541
+ """Read the current training status from file."""
542
+ try:
543
+ if Path(status_file_path).exists():
544
+ with open(status_file_path, "r") as f:
545
+ return f.read().strip()
546
+ except Exception:
547
+ pass
548
+ return "Training status unavailable"
549
+
550
+
551
+ def _persist_uploaded_file(file_obj) -> Optional[Path]:
552
+ if file_obj is None:
553
+ return None
554
+
555
+ if isinstance(file_obj, (str, Path)):
556
+ source = Path(file_obj)
557
+ original_name = source.name
558
+ else:
559
+ source = Path(getattr(file_obj, "name", "") or getattr(file_obj, "path", ""))
560
+ original_name = getattr(file_obj, "orig_name", source.name) or source.name
561
+ # Path("") normalises to "."; is_file() rejects that case as well as
+ # missing paths and directories.
+ if not source.is_file():
562
+ return None
563
+
564
+ original_name = Path(original_name).name or source.name
565
+
566
+ base_path = Path(original_name)
567
+ destination = TRAINING_UPLOAD_DIR / base_path.name
568
+ counter = 1
569
+ while destination.exists():
570
+ suffix = base_path.suffix or ".csv"
571
+ destination = TRAINING_UPLOAD_DIR / f"{base_path.stem}_{counter}{suffix}"
572
+ counter += 1
573
+
574
+ shutil.copy2(source, destination)
575
+ return destination
576
+
577
+
578
+ def prepare_training_paths(
579
+ paths: Sequence[str], current_label: str, cleanup_missing: bool = False
580
+ ):
581
+ valid_paths: List[str] = []
582
+ notes: List[str] = []
583
+ columns_map: Dict[str, str] = {}
584
+ for path in paths:
585
+ try:
586
+ df = load_measurement_csv(path)
587
+ except Exception as exc: # pragma: no cover - user file diagnostics
588
+ notes.append(f"⚠️ Skipped {Path(path).name}: {exc}")
589
+ if cleanup_missing:
590
+ try:
591
+ Path(path).unlink(missing_ok=True)
592
+ except Exception:
593
+ pass
594
+ continue
595
+ valid_paths.append(str(path))
596
+ for col in df.columns:
597
+ columns_map[_normalise_header(col)] = str(col)
598
+
599
+ summary = summarise_training_files(valid_paths, notes)
600
+ preferred = current_label or LABEL_COLUMN
601
+ dropdown_choices = (
602
+ sorted(columns_map.values()) if columns_map else [preferred or LABEL_COLUMN]
603
+ )
604
+ guessed = guess_label_from_columns(dropdown_choices, preferred)
605
+ dropdown_value = guessed or preferred or LABEL_COLUMN
606
+
607
+ return (
608
+ valid_paths,
609
+ summary,
610
+ gr.update(choices=dropdown_choices, value=dropdown_value),
611
+ )
612
+
613
+
614
+ def append_training_files(new_files, existing_paths: Sequence[str], current_label: str):
615
+ if isinstance(existing_paths, (str, Path)):
616
+ paths: List[str] = [str(existing_paths)]
617
+ elif existing_paths is None:
618
+ paths = []
619
+ else:
620
+ paths = list(existing_paths)
621
+ if new_files:
622
+ for file in new_files:
623
+ persisted = _persist_uploaded_file(file)
624
+ if persisted is None:
625
+ continue
626
+ path_str = str(persisted)
627
+ if path_str not in paths:
628
+ paths.append(path_str)
629
+
630
+ return prepare_training_paths(paths, current_label, cleanup_missing=True)
631
+
632
+
633
+ def load_repository_training_files(current_label: str, force_refresh: bool = False):
634
+ # force_refresh needs no cache cleanup: downloads are on-demand and
+ # previously downloaded files are kept, so the flag only drives the
+ # downstream UI updates triggered by the caller.
640
+
641
+ csv_paths = sorted(
642
+ str(path) for path in TRAINING_DATA_DIR.rglob("*.csv") if path.is_file()
643
+ )
644
+ if not csv_paths:
645
+ message = (
646
+ "No local database CSVs are available yet. Use the database browser "
647
+ "below to download specific days before training."
648
+ )
649
+ default_label = current_label or LABEL_COLUMN or "Fault"
650
+ return (
651
+ [],
652
+ message,
653
+ gr.update(choices=[default_label], value=default_label),
654
+ message,
655
+ )
656
+
657
+ valid_paths, summary, label_update = prepare_training_paths(
658
+ csv_paths, current_label, cleanup_missing=False
659
+ )
660
+
661
+ info = (
662
+ f"Ready with {len(valid_paths)} CSV file(s) cached locally under "
663
+ f"the database cache `{TRAINING_DATA_DIR}`."
664
+ )
665
+
666
+ return valid_paths, summary, label_update, info
667
+
668
+
669
+ def refresh_remote_browser(force_refresh: bool = False):
670
+ if force_refresh:
671
+ GITHUB_CONTENT_CACHE.clear()
672
+ try:
673
+ years = list_remote_years(force_refresh=force_refresh)
674
+ if years:
675
+ message = "Select a year, month, and day to list available CSV files."
676
+ else:
677
+ message = (
678
+ "⚠️ No directories were found in the database root. Verify the upstream "
679
+ "structure."
680
+ )
681
+ return (
682
+ gr.update(choices=years, value=None),
683
+ gr.update(choices=[], value=None),
684
+ gr.update(choices=[], value=None),
685
+ gr.update(choices=[], value=[]),
686
+ message,
687
+ )
688
+ except Exception as exc:
689
+ return (
690
+ gr.update(choices=[], value=None),
691
+ gr.update(choices=[], value=None),
692
+ gr.update(choices=[], value=None),
693
+ gr.update(choices=[], value=[]),
694
+ f"⚠️ Failed to query database: {exc}",
695
+ )
696
+
697
+
698
+ def on_year_change(year: Optional[str]):
699
+ if not year:
700
+ return (
701
+ gr.update(choices=[], value=None),
702
+ gr.update(choices=[], value=None),
703
+ gr.update(choices=[], value=[]),
704
+ "Select a year to continue.",
705
+ )
706
+ try:
707
+ months = list_remote_months(year)
708
+ message = (
709
+ f"Year `{year}` selected. Choose a month to drill down."
710
+ if months
711
+ else f"⚠️ No months available under `{year}`."
712
+ )
713
+ return (
714
+ gr.update(choices=months, value=None),
715
+ gr.update(choices=[], value=None),
716
+ gr.update(choices=[], value=[]),
717
+ message,
718
+ )
719
+ except Exception as exc:
720
+ return (
721
+ gr.update(choices=[], value=None),
722
+ gr.update(choices=[], value=None),
723
+ gr.update(choices=[], value=[]),
724
+ f"⚠️ Failed to list months: {exc}",
725
+ )
726
+
727
+
728
+ def on_month_change(year: Optional[str], month: Optional[str]):
729
+ if not year or not month:
730
+ return (
731
+ gr.update(choices=[], value=None),
732
+ gr.update(choices=[], value=[]),
733
+ "Select a month to continue.",
734
+ )
735
+ try:
736
+ days = list_remote_days(year, month)
737
+ message = (
738
+ f"Month `{year}/{month}` ready. Pick a day to view files."
739
+ if days
740
+ else f"⚠️ No day folders found under `{year}/{month}`."
741
+ )
742
+ return (
743
+ gr.update(choices=days, value=None),
744
+ gr.update(choices=[], value=[]),
745
+ message,
746
+ )
747
+ except Exception as exc:
748
+ return (
749
+ gr.update(choices=[], value=None),
750
+ gr.update(choices=[], value=[]),
751
+ f"⚠️ Failed to list days: {exc}",
752
+ )
753
+
754
+
755
+ def on_day_change(year: Optional[str], month: Optional[str], day: Optional[str]):
756
+ if not year or not month or not day:
757
+ return (
758
+ gr.update(choices=[], value=[]),
759
+ "Select a day to load file names.",
760
+ )
761
+ try:
762
+ files = list_remote_files(year, month, day)
763
+ message = (
764
+ f"{len(files)} file(s) available for `{year}/{month}/{day}`."
765
+ if files
766
+ else f"⚠️ No CSV files found under `{year}/{month}/{day}`."
767
+ )
768
+ return (
769
+ gr.update(choices=files, value=[]),
770
+ message,
771
+ )
772
+ except Exception as exc:
773
+ return (
774
+ gr.update(choices=[], value=[]),
775
+ f"⚠️ Failed to list files: {exc}",
776
+ )
777
+
778
+
779
+ def download_selected_files(
780
+ year: Optional[str],
781
+ month: Optional[str],
782
+ day: Optional[str],
783
+ filenames: Sequence[str],
784
+ current_label: str,
785
+ ):
786
+ if not filenames:
787
+ message = "Select at least one CSV before downloading."
788
+ local = load_repository_training_files(current_label)
789
+ return (*local, gr.update(), message)
790
+
791
+ success: List[str] = []
792
+ notes: List[str] = []
793
+ for filename in filenames:
794
+ try:
795
+ path = download_repository_file(
796
+ year or "", month or "", day or "", filename
797
+ )
798
+ success.append(str(path))
799
+ except Exception as exc:
800
+ notes.append(f"⚠️ {filename}: {exc}")
801
+
802
+ local = load_repository_training_files(current_label)
803
+
804
+ message_lines = []
805
+ if success:
806
+ message_lines.append(
807
+ f"Downloaded {len(success)} file(s) to the database cache `{TRAINING_DATA_DIR}`."
808
+ )
809
+ if notes:
810
+ message_lines.extend(notes)
811
+ if not message_lines:
812
+ message_lines.append("No files were downloaded.")
813
+
814
+ return (*local, gr.update(value=[]), "\n".join(message_lines))
815
+
816
+
817
+ def download_day_bundle(
818
+ year: Optional[str],
819
+ month: Optional[str],
820
+ day: Optional[str],
821
+ current_label: str,
822
+ ):
823
+ if not (year and month and day):
824
+ local = load_repository_training_files(current_label)
825
+ return (
826
+ *local,
827
+ gr.update(),
828
+ "Select a year, month, and day before downloading an entire day.",
829
+ )
830
+
831
+ try:
832
+ files = list_remote_files(year, month, day)
833
+ except Exception as exc:
834
+ local = load_repository_training_files(current_label)
835
+ return (
836
+ *local,
837
+ gr.update(),
838
+ f"⚠️ Failed to list CSVs for `{year}/{month}/{day}`: {exc}",
839
+ )
840
+
841
+ if not files:
842
+ local = load_repository_training_files(current_label)
843
+ return (
844
+ *local,
845
+ gr.update(),
846
+ f"No CSV files were found for `{year}/{month}/{day}`.",
847
+ )
848
+
849
+ result = list(download_selected_files(year, month, day, files, current_label))
+ result[-1] = (
+ f"Requested all {len(files)} CSV file(s) for `{year}/{month}/{day}`.\n"
+ f"{result[-1]}"
+ )
854
+ return tuple(result)
855
+
856
+
857
+ def download_month_bundle(
858
+ year: Optional[str], month: Optional[str], current_label: str
859
+ ):
860
+ if not (year and month):
861
+ local = load_repository_training_files(current_label)
862
+ return (
863
+ *local,
864
+ gr.update(),
865
+ "Select a year and month before downloading an entire month.",
866
+ )
867
+
868
+ try:
869
+ days = list_remote_days(year, month)
870
+ except Exception as exc:
871
+ local = load_repository_training_files(current_label)
872
+ return (
873
+ *local,
874
+ gr.update(),
875
+ f"⚠️ Failed to enumerate days for `{year}/{month}`: {exc}",
876
+ )
877
+
878
+ if not days:
879
+ local = load_repository_training_files(current_label)
880
+ return (
881
+ *local,
882
+ gr.update(),
883
+ f"No day folders were found for `{year}/{month}`.",
884
+ )
885
+
886
+ downloaded = 0
887
+ notes: List[str] = []
888
+ for day in days:
889
+ try:
890
+ files = list_remote_files(year, month, day)
891
+ except Exception as exc:
892
+ notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
893
+ continue
894
+ if not files:
895
+ notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
896
+ continue
897
+ for filename in files:
898
+ try:
899
+ download_repository_file(year, month, day, filename)
900
+ downloaded += 1
901
+ except Exception as exc:
902
+ notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
903
+
904
+ local = load_repository_training_files(current_label)
905
+ message_lines = []
906
+ if downloaded:
907
+ message_lines.append(
908
+ f"Downloaded {downloaded} CSV file(s) for `{year}/{month}` into the "
909
+ f"database cache `{TRAINING_DATA_DIR}`."
910
+ )
911
+ message_lines.extend(notes)
912
+ if not message_lines:
913
+ message_lines.append("No files were downloaded.")
914
+
915
+ return (*local, gr.update(value=[]), "\n".join(message_lines))
916
+
917
+
918
+ def download_year_bundle(year: Optional[str], current_label: str):
919
+ if not year:
920
+ local = load_repository_training_files(current_label)
921
+ return (
922
+ *local,
923
+ gr.update(),
924
+ "Select a year before downloading an entire year of CSVs.",
925
+ )
926
+
927
+ try:
928
+ months = list_remote_months(year)
929
+ except Exception as exc:
930
+ local = load_repository_training_files(current_label)
931
+ return (
932
+ *local,
933
+ gr.update(),
934
+ f"⚠️ Failed to enumerate months for `{year}`: {exc}",
935
+ )
936
+
937
+ if not months:
938
+ local = load_repository_training_files(current_label)
939
+ return (
940
+ *local,
941
+ gr.update(),
942
+ f"No month folders were found for `{year}`.",
943
+ )
944
+
945
+ downloaded = 0
946
+ notes: List[str] = []
947
+ for month in months:
948
+ try:
949
+ days = list_remote_days(year, month)
950
+ except Exception as exc:
951
+ notes.append(f"⚠️ Failed to list `{year}/{month}`: {exc}")
952
+ continue
953
+ if not days:
954
+ notes.append(f"⚠️ No day folders in `{year}/{month}`.")
955
+ continue
956
+ for day in days:
957
+ try:
958
+ files = list_remote_files(year, month, day)
959
+ except Exception as exc:
960
+ notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
961
+ continue
962
+ if not files:
963
+ notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
964
+ continue
965
+ for filename in files:
966
+ try:
967
+ download_repository_file(year, month, day, filename)
968
+ downloaded += 1
969
+ except Exception as exc:
970
+ notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
971
+
972
+ local = load_repository_training_files(current_label)
973
+ message_lines = []
974
+ if downloaded:
975
+ message_lines.append(
976
+ f"Downloaded {downloaded} CSV file(s) for `{year}` into the "
977
+ f"database cache `{TRAINING_DATA_DIR}`."
978
+ )
979
+ message_lines.extend(notes)
980
+ if not message_lines:
981
+ message_lines.append("No files were downloaded.")
982
+
983
+ return (*local, gr.update(value=[]), "\n".join(message_lines))
984
+
985
+
986
+ def clear_downloaded_cache(current_label: str):
987
+ status_message = ""
988
+ try:
989
+ if TRAINING_DATA_DIR.exists():
990
+ shutil.rmtree(TRAINING_DATA_DIR)
991
+ TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
992
+ status_message = (
993
+ f"Cleared all downloaded CSVs from database cache `{TRAINING_DATA_DIR}`."
994
+ )
995
+ except Exception as exc:
996
+ status_message = f"⚠️ Failed to clear database cache: {exc}"
997
+
998
+ local = load_repository_training_files(current_label, force_refresh=True)
999
+ remote = list(refresh_remote_browser(force_refresh=False))
1000
+ if status_message:
1001
+ previous = remote[-1]
1002
+ if isinstance(previous, str) and previous:
1003
+ remote[-1] = f"{status_message}\n{previous}"
1004
+ else:
1005
+ remote[-1] = status_message
1006
+
1007
+ return (*local, *remote)
1008
+
1009
+
1010
+ def normalise_output_directory(directory: Optional[str]) -> Path:
1011
+ base = Path(directory or MODEL_OUTPUT_DIR)
1012
+ base = base.expanduser()
1013
+ if not base.is_absolute():
1014
+ base = (Path.cwd() / base).resolve()
1015
+ return base
1016
+
1017
+
1018
+ def resolve_output_path(
+ directory: Optional[Union[Path, str]], filename: Optional[str], fallback: str
+ ) -> Path:
+ if isinstance(directory, Path):
+ base = directory
+ else:
+ base = normalise_output_directory(directory)
+ # Path("") stringifies to ".", so test the raw filename string here; testing
+ # str(Path(filename or "")) is always truthy and makes the fallback branch
+ # unreachable (an empty filename would resolve to the base directory itself).
+ if filename and filename.strip():
+ candidate = Path(filename).expanduser()
+ if candidate.is_absolute():
+ return candidate
+ return (base / candidate).resolve()
+ return (base / fallback).resolve()
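+
+ # Hypothetical examples of the resolution rules above (paths are illustrative):
+ #   resolve_output_path("model", "pmu_tcn_model.keras", "m.keras") -> <cwd>/model/pmu_tcn_model.keras
+ #   resolve_output_path("model", "/tmp/out.keras", "m.keras")      -> /tmp/out.keras
+ #   resolve_output_path("model", "", "m.keras")                    -> <cwd>/model/m.keras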
1031
+
1032
+
1033
+ ARTIFACT_FILE_EXTENSIONS: Tuple[str, ...] = (
1034
+ ".keras",
1035
+ ".h5",
1036
+ ".joblib",
1037
+ ".pkl",
1038
+ ".json",
1039
+ ".onnx",
1040
+ ".zip",
1041
+ ".txt",
1042
+ )
1043
+
1044
+
1045
+ def gather_directory_choices(current: Optional[str]) -> Tuple[List[str], str]:
1046
+ base = normalise_output_directory(current or str(MODEL_OUTPUT_DIR))
1047
+ candidates = {str(base)}
1048
+ try:
1049
+ for candidate in base.parent.iterdir():
1050
+ if candidate.is_dir():
1051
+ candidates.add(str(candidate.resolve()))
1052
+ except Exception:
1053
+ pass
1054
+ return sorted(candidates), str(base)
1055
+
1056
+
1057
+ def gather_artifact_choices(
1058
+ directory: Optional[str], selection: Optional[str] = None
1059
+ ) -> Tuple[List[Tuple[str, str]], Optional[str]]:
1060
+ base = normalise_output_directory(directory)
1061
+ choices: List[Tuple[str, str]] = []
1062
+ selected_value: Optional[str] = None
1063
+ if base.exists():
1064
+ try:
1065
+ artifacts = sorted(
1066
+ [
1067
+ path
1068
+ for path in base.iterdir()
1069
+ if path.is_file()
1070
+ and (
1071
+ not ARTIFACT_FILE_EXTENSIONS
1072
+ or path.suffix.lower() in ARTIFACT_FILE_EXTENSIONS
1073
+ )
1074
+ ],
1075
+ key=lambda path: path.name.lower(),
1076
+ )
1077
+ choices = [(artifact.name, str(artifact)) for artifact in artifacts]
1078
+ except Exception:
1079
+ choices = []
1080
+
1081
+ if selection and any(value == selection for _, value in choices):
1082
+ selected_value = selection
1083
+ elif choices:
1084
+ selected_value = choices[0][1]
1085
+
1086
+ return choices, selected_value
1087
+
1088
+
1089
+ def download_button_state(path: Optional[Union[str, Path]]):
1090
+ if not path:
1091
+ return gr.update(value=None, visible=False)
1092
+ candidate = Path(path)
1093
+ if candidate.exists():
1094
+ return gr.update(value=str(candidate), visible=True)
1095
+ return gr.update(value=None, visible=False)
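+
+ # Illustrative behaviour: download_button_state("/tmp/missing.keras") hides the
+ # button, while a path that exists makes it visible with that file attached as
+ # the download value.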
1096
+
1097
+
1098
+ def clear_training_files():
1099
+ default_label = LABEL_COLUMN or "Fault"
1100
+ for cached_file in TRAINING_UPLOAD_DIR.glob("*"):
1101
+ try:
1102
+ if cached_file.is_file():
1103
+ cached_file.unlink(missing_ok=True)
1104
+ except Exception:
1105
+ pass
1106
+ return (
1107
+ [],
1108
+ "No training files selected.",
1109
+ gr.update(choices=[default_label], value=default_label),
1110
+ gr.update(value=None),
1111
+ )
1112
+
1113
+
1114
+ PROJECT_OVERVIEW_MD = """
1115
+ ## Project Overview
1116
+
1117
+ This project focuses on classifying faults in electrical transmission lines and
1118
+ grid-connected photovoltaic (PV) systems by combining ensemble learning
1119
+ techniques with deep neural architectures.
1120
+
1121
+ ## Datasets
1122
+
1123
+ ### Transmission Line Fault Dataset
1124
+ - 134,406 samples collected from Phasor Measurement Units (PMUs)
1125
+ - 14 monitored channels covering currents, voltages, magnitudes, frequency, and phase angles
1126
+ - Labels span symmetrical and asymmetrical faults: NF, L-G, LL, LL-G, LLL, and LLL-G
1127
+ - Time span: 0 to 5.7 seconds with high-frequency sampling
1128
+
1129
+ ### Grid-Connected PV System Fault Dataset
1130
+ - 2,163,480 samples from 16 experimental scenarios
1131
+ - 14 features including PV array measurements (Ipv, Vpv, Vdc), three-phase currents/voltages, aggregate magnitudes (Iabc, Vabc), and frequency indicators (If, Vf)
1132
+ - Captures array, inverter, grid anomaly, feedback sensor, and MPPT controller faults at 9.9989 μs sampling intervals
1133
+
1134
+ ## Data Format Quick Reference
1135
+
1136
+ Each measurement file may be comma or tab separated and typically exposes the
1137
+ following ordered columns:
1138
+
1139
+ 1. `Timestamp`
1140
+ 2. `[325] UPMU_SUB22:FREQ` – system frequency (Hz)
1141
+ 3. `[326] UPMU_SUB22:DFDT` – frequency rate-of-change
1142
+ 4. `[327] UPMU_SUB22:FLAG` – PMU status flag
1143
+ 5. `[328] UPMU_SUB22-L1:MAG` – phase A voltage magnitude
1144
+ 6. `[329] UPMU_SUB22-L1:ANG` – phase A voltage angle
1145
+ 7. `[330] UPMU_SUB22-L2:MAG` – phase B voltage magnitude
1146
+ 8. `[331] UPMU_SUB22-L2:ANG` – phase B voltage angle
1147
+ 9. `[332] UPMU_SUB22-L3:MAG` – phase C voltage magnitude
1148
+ 10. `[333] UPMU_SUB22-L3:ANG` – phase C voltage angle
1149
+ 11. `[334] UPMU_SUB22-C1:MAG` – phase A current magnitude
1150
+ 12. `[335] UPMU_SUB22-C1:ANG` – phase A current angle
1151
+ 13. `[336] UPMU_SUB22-C2:MAG` – phase B current magnitude
1152
+ 14. `[337] UPMU_SUB22-C2:ANG` – phase B current angle
1153
+ 15. `[338] UPMU_SUB22-C3:MAG` – phase C current magnitude
1154
+ 16. `[339] UPMU_SUB22-C3:ANG` – phase C current angle
1155
+
1156
+ The training tab automatically downloads the latest CSV exports from the
1157
+ `VincentCroft/ThesisModelData` repository and concatenates them before building
1158
+ sliding windows.
1159
+
1160
+ ## Models Developed
1161
+
1162
+ 1. **Support Vector Machine (SVM)** – provides the classical machine learning baseline with balanced accuracy across both datasets (85% PMU / 83% PV).
1163
+ 2. **CNN-LSTM** – couples convolutional feature extraction with temporal memory, achieving 92% PMU / 89% PV accuracy.
1164
+ 3. **Temporal Convolutional Network (TCN)** – leverages dilated convolutions for long-range context and delivers the best trade-off between accuracy and training time (94% PMU / 91% PV).
1165
+
1166
+ ## Results Summary
1167
+
1168
+ - **Transmission Line Fault Classification**: SVM 85%, CNN-LSTM 92%, TCN 94%
1169
+ - **PV System Fault Classification**: SVM 83%, CNN-LSTM 89%, TCN 91%
1170
+
1171
+ Use the **Inference** tab to score new PMU/PV windows and the **Training** tab to
1172
+ fine-tune or retrain any of the supported models directly within Hugging Face
1173
+ Spaces. The logs panel will surface TensorBoard archives whenever deep-learning
1174
+ models are trained.
1175
+ """
1176
+
1177
+
1178
+ def load_measurement_csv(path: str) -> pd.DataFrame:
1179
+ """Read a PMU/PV measurement file with flexible separators and column mapping."""
1180
+
1181
+ try:
1182
+ df = pd.read_csv(path, sep=None, engine="python", encoding="utf-8-sig")
1183
+ except Exception:
1184
+ df = None
1185
+ for separator in ("\t", ",", ";"):
1186
+ try:
1187
+ df = pd.read_csv(
1188
+ path, sep=separator, engine="python", encoding="utf-8-sig"
1189
+ )
1190
+ break
1191
+ except Exception:
1192
+ df = None
1193
+ if df is None:
1194
+ raise
1195
+
1196
+ # Clean column names
1197
+ df.columns = [str(col).strip() for col in df.columns]
1198
+
1199
+ print(f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns")
1200
+ print(f"Columns: {list(df.columns)}")
1201
+ print(f"Data shape: {df.shape}")
1202
+
1203
+ # Check if we have enough data for training
1204
+ if len(df) < 100:
1205
+ print(
1206
+ f"Warning: Only {len(df)} rows of data. Recommend at least 1000 rows for effective training."
1207
+ )
1208
+
1209
+ # Check for label column
1210
+ has_label = any(
1211
+ col.lower() in ["fault", "label", "class", "target"] for col in df.columns
1212
+ )
1213
+ if not has_label:
1214
+ print(
1215
+ "Warning: No label column found. Adding dummy 'Fault' column with value 'Normal' for all samples."
1216
+ )
1217
+ df["Fault"] = "Normal" # Add dummy label for training
1218
+
1219
+ # Create column mapping - map similar column names to expected format
1220
+ column_mapping = {}
1221
+ expected_cols = list(REQUIRED_PMU_COLUMNS)
1222
+
1223
+ # If we have at least the right number of numeric columns after Timestamp, use positional mapping
1224
+ if "Timestamp" in df.columns:
1225
+ numeric_cols = [col for col in df.columns if col != "Timestamp"]
1226
+ if len(numeric_cols) >= len(expected_cols):
1227
+ # Map by position (after Timestamp)
1228
+ for i, expected_col in enumerate(expected_cols):
1229
+ if i < len(numeric_cols):
1230
+ column_mapping[numeric_cols[i]] = expected_col
1231
+
1232
+ # Rename columns to match expected format
1233
+ df = df.rename(columns=column_mapping)
1234
+
1235
+ # Check if we have the required columns after mapping
1236
+ missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]
1237
+ if missing:
1238
+ # If still missing, try a more flexible approach
1239
+ available_numeric = df.select_dtypes(include=[np.number]).columns.tolist()
1240
+ if len(available_numeric) >= len(expected_cols):
1241
+ # Use the first N numeric columns
1242
+ for i, expected_col in enumerate(expected_cols):
1243
+ if i < len(available_numeric):
1244
+ if available_numeric[i] not in df.columns:
1245
+ continue
1246
+ df = df.rename(columns={available_numeric[i]: expected_col})
1247
+
1248
+ # Recheck missing columns
1249
+ missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]
1250
+
1251
+ if missing:
1252
+ missing_str = ", ".join(missing)
1253
+ available_str = ", ".join(df.columns.tolist())
1254
+ raise ValueError(
1255
+ f"Missing required PMU feature columns: {missing_str}. "
1256
+ f"Available columns: {available_str}. "
1257
+ "Please ensure your CSV has the correct format with Timestamp followed by PMU measurements."
1258
+ )
1259
+
1260
+ return df
1261
+
1262
+
1263
+ def apply_metadata(metadata: Dict[str, Any]) -> None:
1264
+ global FEATURE_COLUMNS, LABEL_CLASSES, LABEL_COLUMN, SEQUENCE_LENGTH, DEFAULT_WINDOW_STRIDE, MODEL_TYPE, MODEL_FORMAT
1265
+ FEATURE_COLUMNS = [
1266
+ str(col) for col in metadata.get("feature_columns", DEFAULT_FEATURE_COLUMNS)
1267
+ ]
1268
+ LABEL_CLASSES = [str(label) for label in metadata.get("label_classes", [])]
1269
+ LABEL_COLUMN = str(metadata.get("label_column", "Fault"))
1270
+ SEQUENCE_LENGTH = int(metadata.get("sequence_length", DEFAULT_SEQUENCE_LENGTH))
1271
+ DEFAULT_WINDOW_STRIDE = int(metadata.get("stride", DEFAULT_STRIDE))
1272
+ MODEL_TYPE = str(metadata.get("model_type", "cnn_lstm")).lower()
1273
+ MODEL_FORMAT = str(
1274
+ metadata.get("model_format", "joblib" if MODEL_TYPE == "svm" else "keras")
1275
+ ).lower()
1276
+
1277
+
1278
+ apply_metadata(METADATA)
1279
+
1280
+
1281
+ def sync_label_classes_from_model(model: Optional[object]) -> None:
1282
+ global LABEL_CLASSES
1283
+ if model is None:
1284
+ return
1285
+ if hasattr(model, "classes_"):
1286
+ LABEL_CLASSES = [str(label) for label in getattr(model, "classes_")]
1287
+ elif not LABEL_CLASSES and hasattr(model, "output_shape"):
1288
+ LABEL_CLASSES = [str(i) for i in range(int(model.output_shape[-1]))]
1289
+
1290
+
1291
+ # Load model and scaler with error handling
1292
+ print("Loading model and scaler...")
1293
+ try:
1294
+ MODEL = try_load_model(MODEL_PATH, MODEL_TYPE, MODEL_FORMAT)
1295
+ print(f"Model loaded: {MODEL is not None}")
1296
+ except Exception as e:
1297
+ print(f"Model loading failed: {e}")
1298
+ MODEL = None
1299
+
1300
+ try:
1301
+ SCALER = try_load_scaler(SCALER_PATH)
1302
+ print(f"Scaler loaded: {SCALER is not None}")
1303
+ except Exception as e:
1304
+ print(f"Scaler loading failed: {e}")
1305
+ SCALER = None
1306
+
1307
+ try:
1308
+ sync_label_classes_from_model(MODEL)
1309
+ print("Label classes synchronized")
1310
+ except Exception as e:
1311
+ print(f"Label sync failed: {e}")
1312
+
1313
+ print("Application initialization completed.")
1314
+ print(
1315
+ f"Ready to start Gradio interface. Model available: {MODEL is not None}, Scaler available: {SCALER is not None}"
1316
+ )
1317
+
1318
+
1319
+ def refresh_artifacts(model_path: Path, scaler_path: Path, metadata_path: Path) -> None:
1320
+ global MODEL_PATH, SCALER_PATH, METADATA_PATH, MODEL, SCALER, METADATA
1321
+ MODEL_PATH = model_path
1322
+ SCALER_PATH = scaler_path
1323
+ METADATA_PATH = metadata_path
1324
+ METADATA = load_metadata(metadata_path)
1325
+ apply_metadata(METADATA)
1326
+ MODEL = try_load_model(model_path, MODEL_TYPE, MODEL_FORMAT)
1327
+ SCALER = try_load_scaler(scaler_path)
1328
+ sync_label_classes_from_model(MODEL)
1329
+
1330
+
1331
+ # --------------------------------------------------------------------------------------
1332
+ # Pre-processing helpers
1333
+ # --------------------------------------------------------------------------------------
1334
+
1335
+
1336
+ def ensure_ready():
1337
+ if MODEL is None or SCALER is None:
1338
+ raise RuntimeError(
1339
+ "The model and feature scaler are not available. Upload the trained model "
1340
+ "(for example `pmu_cnn_lstm_model.keras`, `pmu_tcn_model.keras`, or `pmu_svm_model.joblib`), "
1341
+ "the feature scaler (`pmu_feature_scaler.pkl`), and the metadata JSON (`pmu_metadata.json`) to the Space root "
1342
+ "or configure the Hugging Face Hub environment variables so the artifacts can be downloaded "
1343
+ "automatically."
1344
+ )
1345
+
1346
+
1347
+ def parse_text_features(text: str) -> np.ndarray:
+ cleaned = re.sub(r"[;\n\t]+", ",", text.strip())
+ # np.fromstring's text-parsing mode is deprecated in NumPy; split the string
+ # and convert explicitly instead.
+ tokens = [tok for tok in cleaned.split(",") if tok.strip()]
+ if not tokens:
+ raise ValueError(
+ "No feature values were parsed. Please enter comma-separated numbers."
+ )
+ return np.array(tokens, dtype=np.float32)
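+
+ # e.g. parse_text_features("1, 2; 3\t4") -> array([1., 2., 3., 4.], dtype=float32)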
1355
+
1356
+
1357
+ def apply_scaler(sequences: np.ndarray) -> np.ndarray:
1358
+ if SCALER is None:
1359
+ return sequences
1360
+ shape = sequences.shape
1361
+ flattened = sequences.reshape(-1, shape[-1])
1362
+ scaled = SCALER.transform(flattened)
1363
+ return scaled.reshape(shape)
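+
+ # Why the reshape round-trip: a batch of shape (n_windows, seq_len, n_features)
+ # is flattened to (n_windows * seq_len, n_features) so the per-feature scaler
+ # fitted during training applies column-wise, then the original shape is
+ # restored. Shapes here are illustrative.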
1364
+
1365
+
1366
+ def make_sliding_windows(
1367
+ data: np.ndarray, sequence_length: int, stride: int
1368
+ ) -> np.ndarray:
1369
+ if data.shape[0] < sequence_length:
1370
+ raise ValueError(
1371
+ f"The dataset contains {data.shape[0]} rows which is less than the requested sequence "
1372
+ f"length {sequence_length}. Provide more samples or reduce the sequence length."
1373
+ )
1374
+ windows = [
1375
+ data[start : start + sequence_length]
1376
+ for start in range(0, data.shape[0] - sequence_length + 1, stride)
1377
+ ]
1378
+ return np.stack(windows)
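+
+ # Windowing arithmetic, with illustrative numbers: 100 rows, sequence_length=32
+ # and stride=4 yield floor((100 - 32) / 4) + 1 = 18 overlapping windows, i.e.
+ # make_sliding_windows(np.zeros((100, 14), np.float32), 32, 4).shape == (18, 32, 14).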
1379
+
1380
+
1381
+ def dataframe_to_sequences(
1382
+ df: pd.DataFrame,
1383
+ *,
1384
+ sequence_length: int,
1385
+ stride: int,
1386
+ feature_columns: Sequence[str],
1387
+ drop_label: bool = True,
1388
+ ) -> np.ndarray:
1389
+ work_df = df.copy()
1390
+ if drop_label and LABEL_COLUMN in work_df.columns:
1391
+ work_df = work_df.drop(columns=[LABEL_COLUMN])
1392
+ if "Timestamp" in work_df.columns:
1393
+ work_df = work_df.sort_values("Timestamp")
1394
+
1395
+ available_cols = [c for c in feature_columns if c in work_df.columns]
1396
+ n_features = len(feature_columns)
1397
+ if available_cols and len(available_cols) == n_features:
1398
+ array = work_df[available_cols].astype(np.float32).to_numpy()
1399
+ return make_sliding_windows(array, sequence_length, stride)
1400
+
1401
+ numeric_df = work_df.select_dtypes(include=[np.number])
1402
+ array = numeric_df.astype(np.float32).to_numpy()
1403
+ if array.shape[1] == n_features * sequence_length:
1404
+ return array.reshape(array.shape[0], sequence_length, n_features)
1405
+ if sequence_length == 1 and array.shape[1] == n_features:
1406
+ return array.reshape(array.shape[0], 1, n_features)
1407
+ raise ValueError(
1408
+ "CSV columns do not match the expected feature layout. Include the full PMU feature set "
1409
+ "or provide pre-shaped sliding window data."
1410
+ )
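+
+ # The three accepted layouts, summarised (shapes illustrative):
+ #   1. long format with all feature columns present -> sliding windows built above;
+ #   2. flat windows with n_features * sequence_length numeric columns per row
+ #      -> each row reshaped to (sequence_length, n_features);
+ #   3. sequence_length == 1 with exactly n_features numeric columns
+ #      -> (rows, 1, n_features).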
1411
+
1412
+
1413
+ def label_name(index: int) -> str:
1414
+ if 0 <= index < len(LABEL_CLASSES):
1415
+ return str(LABEL_CLASSES[index])
1416
+ return f"class_{index}"
1417
+
1418
+
1419
+ def format_predictions(probabilities: np.ndarray) -> pd.DataFrame:
1420
+ rows: List[Dict[str, object]] = []
1421
+ order = np.argsort(probabilities, axis=1)[:, ::-1]
1422
+ for idx, (prob_row, ranking) in enumerate(zip(probabilities, order)):
1423
+ top_idx = int(ranking[0])
1424
+ top_label = label_name(top_idx)
1425
+ top_conf = float(prob_row[top_idx])
1426
+ top3 = [f"{label_name(i)} ({prob_row[i]*100:.2f}%)" for i in ranking[:3]]
1427
+ rows.append(
1428
+ {
1429
+ "window": idx,
1430
+ "predicted_label": top_label,
1431
+ "confidence": round(top_conf, 4),
1432
+ "top3": " | ".join(top3),
1433
+ }
1434
+ )
1435
+ return pd.DataFrame(rows)
1436
+
1437
+
1438
+ def probabilities_to_json(probabilities: np.ndarray) -> List[Dict[str, object]]:
1439
+ payload: List[Dict[str, object]] = []
1440
+ for idx, prob_row in enumerate(probabilities):
1441
+ payload.append(
1442
+ {
1443
+ "window": int(idx),
1444
+ "probabilities": {
1445
+ label_name(i): float(prob_row[i]) for i in range(prob_row.shape[0])
1446
+ },
1447
+ }
1448
+ )
1449
+ return payload
1450
+
1451
+
1452
+ def predict_sequences(
1453
+ sequences: np.ndarray,
1454
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
1455
+ ensure_ready()
1456
+ sequences = apply_scaler(sequences.astype(np.float32))
1457
+ if MODEL_TYPE == "svm":
1458
+ flattened = sequences.reshape(sequences.shape[0], -1)
1459
+ if hasattr(MODEL, "predict_proba"):
1460
+ probs = MODEL.predict_proba(flattened)
1461
+ else:
1462
+ raise RuntimeError(
1463
+ "Loaded SVM model does not expose predict_proba. Retrain with probability=True."
1464
+ )
1465
+ else:
1466
+ probs = MODEL.predict(sequences, verbose=0)
1467
+ table = format_predictions(probs)
1468
+ json_probs = probabilities_to_json(probs)
1469
+ architecture = MODEL_TYPE.replace("_", "-").upper()
1470
+ status = f"Generated {len(sequences)} windows. {architecture} model output dimension: {probs.shape[1]}."
1471
+ return status, table, json_probs
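+
+ # Shape contract (illustrative): sequences (n_windows, seq_len, n_features)
+ # -> probs (n_windows, n_classes); on the SVM path the windows are flattened
+ # to (n_windows, seq_len * n_features) before predict_proba.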
1472
+
1473
+
1474
+ def predict_from_text(
1475
+ text: str, sequence_length: int
1476
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
1477
+ arr = parse_text_features(text)
1478
+ n_features = len(FEATURE_COLUMNS)
1479
+ if arr.size % n_features != 0:
1480
+ raise ValueError(
1481
+ f"The number of values ({arr.size}) is not a multiple of the feature dimension "
1482
+ f"({n_features}). Provide values in groups of {n_features}."
1483
+ )
1484
+ timesteps = arr.size // n_features
1485
+ if timesteps != sequence_length:
1486
+ raise ValueError(
1487
+ f"Detected {timesteps} timesteps which does not match the configured sequence length "
1488
+ f"({sequence_length})."
1489
+ )
1490
+ sequences = arr.reshape(1, sequence_length, n_features)
1491
+ status, table, probs = predict_sequences(sequences)
1492
+ status = f"Single window prediction complete. {status}"
1493
+ return status, table, probs
1494
+
1495
+
1496
+ def predict_from_csv(
1497
+ file_obj, sequence_length: int, stride: int
1498
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
1499
+ df = load_measurement_csv(file_obj.name)
1500
+ sequences = dataframe_to_sequences(
1501
+ df,
1502
+ sequence_length=sequence_length,
1503
+ stride=stride,
1504
+ feature_columns=FEATURE_COLUMNS,
1505
+ )
1506
+ status, table, probs = predict_sequences(sequences)
1507
+ status = f"CSV processed successfully. {status}"
1508
+ return status, table, probs
1509
+
1510
+
1511
+ # --------------------------------------------------------------------------------------
1512
+ # Training helpers
1513
+ # --------------------------------------------------------------------------------------
1514
+
1515
+
1516
+ def classification_report_to_dataframe(report: Dict[str, Any]) -> pd.DataFrame:
1517
+ rows: List[Dict[str, Any]] = []
1518
+ for label, metrics in report.items():
1519
+ if isinstance(metrics, dict):
1520
+ row = {"label": label}
1521
+ for key, value in metrics.items():
1522
+ if key == "support":
1523
+ row[key] = int(value)
1524
+ else:
1525
+ row[key] = round(float(value), 4)
1526
+ rows.append(row)
1527
+ else:
1528
+ rows.append({"label": label, "accuracy": round(float(metrics), 4)})
1529
+ return pd.DataFrame(rows)
1530
+
1531
+
1532
+ def confusion_matrix_to_dataframe(
1533
+ confusion: Sequence[Sequence[float]], labels: Sequence[str]
1534
+ ) -> pd.DataFrame:
1535
+ if not confusion:
1536
+ return pd.DataFrame()
1537
+ df = pd.DataFrame(confusion, index=list(labels), columns=list(labels))
1538
+ df.index.name = "True Label"
1539
+ df.columns.name = "Predicted Label"
1540
+ return df
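+
+ # e.g. confusion_matrix_to_dataframe([[5, 1], [0, 4]], ["NF", "L-G"]) yields a
+ # 2x2 frame with true labels on the index and predicted labels on the columns.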
1541
+
1542
+
1543
+ # --------------------------------------------------------------------------------------
1544
+ # Gradio interface
1545
+ # --------------------------------------------------------------------------------------
1546
+
1547
+
1548
+ def build_interface() -> gr.Blocks:
1549
+ theme = gr.themes.Soft(
1550
+ primary_hue="sky", secondary_hue="blue", neutral_hue="gray"
1551
+ ).set(
1552
+ body_background_fill="#1f1f1f",
1553
+ body_text_color="#f5f5f5",
1554
+ block_background_fill="#262626",
1555
+ block_border_color="#333333",
1556
+ button_primary_background_fill="#5ac8fa",
1557
+ button_primary_background_fill_hover="#48b5eb",
1558
+ button_primary_border_color="#38bdf8",
1559
+ button_primary_text_color="#0f172a",
1560
+ button_secondary_background_fill="#3f3f46",
1561
+ button_secondary_text_color="#f5f5f5",
1562
+ )
1563
+
1564
+ def _normalise_directory_string(value: Optional[Union[str, Path]]) -> str:
1565
+ if value is None:
1566
+ return ""
1567
+ path = Path(value).expanduser()
1568
+ try:
1569
+ return str(path.resolve())
1570
+ except Exception:
1571
+ return str(path)
1572
+
1573
+ with gr.Blocks(
1574
+ title="Fault Classification - PMU Data", theme=theme, css=APP_CSS
1575
+ ) as demo:
1576
+ gr.Markdown("# Fault Classification for PMU & PV Data")
1577
+ gr.Markdown(
1578
+ "🖥️ TensorFlow is locked to CPU execution so the Space can run without CUDA drivers."
1579
+ )
1580
+ if MODEL is None or SCALER is None:
1581
+ gr.Markdown(
1582
+ "⚠️ **Artifacts Missing** — Upload `pmu_cnn_lstm_model.keras`, "
1583
+ "`pmu_feature_scaler.pkl`, and `pmu_metadata.json` to enable inference, "
1584
+ "or configure the Hugging Face Hub environment variables so they can be downloaded."
1585
+ )
1586
+ else:
1587
+ class_count = len(LABEL_CLASSES) if LABEL_CLASSES else "unknown"
1588
+ gr.Markdown(
1589
+ f"Loaded a **{MODEL_TYPE.upper()}** model ({MODEL_FORMAT.upper()}) with "
1590
+ f"{len(FEATURE_COLUMNS)} features, sequence length **{SEQUENCE_LENGTH}**, and "
1591
+ f"{class_count} target classes. Use the tabs below to run inference or fine-tune "
1592
+ "the model with your own CSV files."
1593
+ )
1594
+
1595
+ with gr.Accordion("Feature Reference", open=False):
1596
+ gr.Markdown(
1597
+ f"Each time window expects **{len(FEATURE_COLUMNS)} features** ordered as follows:\n"
1598
+ + "\n".join(f"- {name}" for name in FEATURE_COLUMNS)
1599
+ )
1600
+ gr.Markdown(
1601
+ f"Default training parameters: **sequence length = {SEQUENCE_LENGTH}**, "
1602
+ f"**stride = {DEFAULT_WINDOW_STRIDE}**. Adjust them in the tabs as needed."
1603
+ )
1604
+
1605
+ with gr.Tabs():
1606
+ with gr.Tab("Overview"):
1607
+ gr.Markdown(PROJECT_OVERVIEW_MD)
1608
+ with gr.Tab("Inference"):
1609
+ gr.Markdown("## Run Inference")
1610
+ with gr.Row():
1611
+ file_in = gr.File(label="Upload PMU CSV", file_types=[".csv"])
1612
+ text_in = gr.Textbox(
1613
+ lines=4,
1614
+ label="Or paste a single window (comma separated)",
1615
+ placeholder="49.97772,1.215825E-38,...",
1616
+ )
1617
+
1618
+ with gr.Row():
1619
+ sequence_length_input = gr.Slider(
1620
+ minimum=1,
1621
+ maximum=max(1, SEQUENCE_LENGTH * 2),
1622
+ step=1,
1623
+ value=SEQUENCE_LENGTH,
1624
+ label="Sequence length (timesteps)",
1625
+ )
1626
+ stride_input = gr.Slider(
1627
+ minimum=1,
1628
+ maximum=max(1, SEQUENCE_LENGTH),
1629
+ step=1,
1630
+ value=max(1, DEFAULT_WINDOW_STRIDE),
1631
+ label="CSV window stride",
1632
+ )
1633
+
1634
+ predict_btn = gr.Button("🚀 Run Inference", variant="primary")
1635
+ status_out = gr.Textbox(label="Status", interactive=False)
1636
+ table_out = gr.Dataframe(
1637
+ headers=["window", "predicted_label", "confidence", "top3"],
1638
+ label="Predictions",
1639
+ interactive=False,
1640
+ )
1641
+ probs_out = gr.JSON(label="Per-window probabilities")
1642
+
1643
+ def _run_prediction(file_obj, text, sequence_length, stride):
1644
+ sequence_length = int(sequence_length)
1645
+ stride = int(stride)
1646
+ try:
1647
+ if file_obj is not None:
1648
+ return predict_from_csv(file_obj, sequence_length, stride)
1649
+ if text and text.strip():
1650
+ return predict_from_text(text, sequence_length)
1651
+ return (
1652
+ "Please upload a CSV file or provide feature values.",
1653
+ pd.DataFrame(),
1654
+ [],
1655
+ )
1656
+ except Exception as exc:
1657
+ return f"Prediction failed: {exc}", pd.DataFrame(), []
1658
+
1659
+ predict_btn.click(
1660
+ _run_prediction,
1661
+ inputs=[file_in, text_in, sequence_length_input, stride_input],
1662
+ outputs=[status_out, table_out, probs_out],
1663
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1664
+ )
1665
+
1666
+ with gr.Tab("Training"):
1667
+ gr.Markdown("## Train or Fine-tune the Model")
1668
+ gr.Markdown(
1669
+ "Training data is automatically downloaded from the database. "
1670
+ "Refresh the cache if new files are added upstream."
1671
+ )
1672
+
1673
+ training_files_state = gr.State([])
1674
+ with gr.Row():
1675
+ with gr.Column(scale=3):
1676
+ training_files_summary = gr.Textbox(
1677
+ label="Database training CSVs",
1678
+ value="Training dataset not loaded yet.",
1679
+ lines=4,
1680
+ interactive=False,
1681
+ elem_id="training-files-summary",
1682
+ )
1683
+ with gr.Column(scale=2, min_width=240):
1684
+ dataset_info = gr.Markdown(
1685
+ "No local database CSVs downloaded yet.",
1686
+ )
1687
+ dataset_refresh = gr.Button(
1688
+ "🔄 Reload dataset from database",
1689
+ variant="secondary",
1690
+ )
1691
+ clear_cache_button = gr.Button(
1692
+ "🧹 Clear downloaded cache",
1693
+ variant="secondary",
1694
+ )
1695
+
1696
+ with gr.Accordion("📂 Database Browser", open=False):
1697
+ gr.Markdown(
1698
+ "Browse the upstream database by date and download only the CSVs you need."
1699
+ )
1700
+ with gr.Row(elem_id="date-browser-row"):
1701
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1702
+ year_selector = gr.Dropdown(label="Year", choices=[])
1703
+ year_download_button = gr.Button(
1704
+ "⬇️ Download year CSVs", variant="secondary"
1705
+ )
1706
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1707
+ month_selector = gr.Dropdown(label="Month", choices=[])
1708
+ month_download_button = gr.Button(
1709
+ "⬇️ Download month CSVs", variant="secondary"
1710
+ )
1711
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1712
+ day_selector = gr.Dropdown(label="Day", choices=[])
1713
+ day_download_button = gr.Button(
1714
+ "⬇️ Download day CSVs", variant="secondary"
1715
+ )
1716
+ with gr.Column(elem_id="available-files-section"):
1717
+ available_files = gr.CheckboxGroup(
1718
+ label="Available CSV files",
1719
+ choices=[],
1720
+ value=[],
1721
+ elem_id="available-files-grid",
1722
+ )
1723
+ download_button = gr.Button(
1724
+ "⬇️ Download selected CSVs",
1725
+ variant="secondary",
1726
+ elem_id="download-selected-button",
1727
+ )
1728
+ repo_status = gr.Markdown(
1729
+ "Click 'Reload dataset from database' to fetch the directory tree."
1730
+ )
1731
+
1732
+ with gr.Row():
1733
+ label_input = gr.Dropdown(
1734
+ value=LABEL_COLUMN,
1735
+ choices=[LABEL_COLUMN],
1736
+ allow_custom_value=True,
1737
+ label="Label column name",
1738
+ )
1739
+ model_selector = gr.Radio(
1740
+ choices=["CNN-LSTM", "TCN", "SVM"],
1741
+ value=(
1742
+ "TCN"
1743
+ if MODEL_TYPE == "tcn"
1744
+ else ("SVM" if MODEL_TYPE == "svm" else "CNN-LSTM")
1745
+ ),
1746
+ label="Model architecture",
1747
+ )
1748
+ sequence_length_train = gr.Slider(
1749
+ minimum=4,
1750
+ maximum=max(32, SEQUENCE_LENGTH * 2),
1751
+ step=1,
1752
+ value=SEQUENCE_LENGTH,
1753
+ label="Sequence length",
1754
+ )
1755
+ stride_train = gr.Slider(
1756
+ minimum=1,
1757
+ maximum=max(32, SEQUENCE_LENGTH * 2),
1758
+ step=1,
1759
+ value=max(1, DEFAULT_WINDOW_STRIDE),
1760
+ label="Stride",
1761
+ )
1762
+
1763
+ model_default = MODEL_FILENAME_BY_TYPE.get(
1764
+ MODEL_TYPE, Path(LOCAL_MODEL_FILE).name
1765
+ )
1766
+
1767
+ with gr.Row():
1768
+ validation_train = gr.Slider(
1769
+ minimum=0.05,
1770
+ maximum=0.4,
1771
+ step=0.05,
1772
+ value=0.2,
1773
+ label="Validation split",
1774
+ )
1775
+ batch_train = gr.Slider(
1776
+ minimum=32,
1777
+ maximum=512,
1778
+ step=32,
1779
+ value=128,
1780
+ label="Batch size",
1781
+ )
1782
+ epochs_train = gr.Slider(
1783
+ minimum=5,
1784
+ maximum=100,
1785
+ step=5,
1786
+ value=50,
1787
+ label="Epochs",
1788
+ )
1789
+
1790
+ directory_choices, directory_default = gather_directory_choices(
1791
+ str(MODEL_OUTPUT_DIR)
1792
+ )
1793
+ artifact_choices, default_artifact = gather_artifact_choices(
1794
+ directory_default
1795
+ )
1796
+
1797
+ with gr.Row():
1798
+ output_directory = gr.Dropdown(
1799
+ value=directory_default,
1800
+ label="Output directory",
1801
+ choices=directory_choices,
1802
+ allow_custom_value=True,
1803
+ )
1804
+ model_name = gr.Textbox(
1805
+ value=model_default,
1806
+ label="Model output filename",
1807
+ )
1808
+ scaler_name = gr.Textbox(
1809
+ value=Path(LOCAL_SCALER_FILE).name,
1810
+ label="Scaler output filename",
1811
+ )
1812
+ metadata_name = gr.Textbox(
1813
+ value=Path(LOCAL_METADATA_FILE).name,
1814
+ label="Metadata output filename",
1815
+ )
1816
+
1817
+ with gr.Row():
1818
+ artifact_browser = gr.Dropdown(
1819
+ label="Saved artifacts in directory",
1820
+ choices=artifact_choices,
1821
+ value=default_artifact,
1822
+ )
1823
+ artifact_download_button = gr.DownloadButton(
1824
+ "⬇️ Download selected artifact",
1825
+ value=default_artifact,
1826
+ visible=bool(default_artifact),
1827
+ variant="secondary",
1828
+ )
1829
+
1830
+ def on_output_directory_change(selected_dir, current_selection):
1831
+ choices, normalised = gather_directory_choices(selected_dir)
1832
+ artifact_options, selected = gather_artifact_choices(
1833
+ normalised, current_selection
1834
+ )
1835
+ return (
1836
+ gr.update(choices=choices, value=normalised),
1837
+ gr.update(choices=artifact_options, value=selected),
1838
+ download_button_state(selected),
1839
+ )
1840
+
1841
+ def on_artifact_change(selected_path):
1842
+ return download_button_state(selected_path)
1843
+
1844
+ output_directory.change(
1845
+ on_output_directory_change,
1846
+ inputs=[output_directory, artifact_browser],
1847
+ outputs=[
1848
+ output_directory,
1849
+ artifact_browser,
1850
+ artifact_download_button,
1851
+ ],
1852
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1853
+ )
1854
+
1855
+ artifact_browser.change(
1856
+ on_artifact_change,
1857
+ inputs=[artifact_browser],
1858
+ outputs=[artifact_download_button],
1859
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1860
+ )
1861
+
1862
+ with gr.Row(elem_id="artifact-download-row"):
1863
+ model_download_button = gr.DownloadButton(
1864
+ "⬇️ Download model file",
1865
+ value=None,
1866
+ visible=False,
1867
+ elem_classes=["artifact-download-button"],
1868
+ )
1869
+ scaler_download_button = gr.DownloadButton(
1870
+ "⬇️ Download scaler file",
1871
+ value=None,
1872
+ visible=False,
1873
+ elem_classes=["artifact-download-button"],
1874
+ )
1875
+ metadata_download_button = gr.DownloadButton(
1876
+ "⬇️ Download metadata file",
1877
+ value=None,
1878
+ visible=False,
1879
+ elem_classes=["artifact-download-button"],
1880
+ )
1881
+ tensorboard_download_button = gr.DownloadButton(
1882
+ "⬇️ Download TensorBoard logs",
1883
+ value=None,
1884
+ visible=False,
1885
+ elem_classes=["artifact-download-button"],
1886
+ )
1887
+
1888
+ # Note: gr.DownloadButton exposes no documented `file_name` attribute; these
+ # assignments are inert, and the served filename comes from the file path
+ # passed as `value`.
+ model_download_button.file_name = Path(LOCAL_MODEL_FILE).name
+ scaler_download_button.file_name = Path(LOCAL_SCALER_FILE).name
+ metadata_download_button.file_name = Path(LOCAL_METADATA_FILE).name
+ tensorboard_download_button.file_name = "tensorboard_logs.zip"
1892
+
1893
+ tensorboard_toggle = gr.Checkbox(
1894
+ value=True,
1895
+ label="Enable TensorBoard logging (creates downloadable archive)",
1896
+ )
1897
+
1898
+ def _suggest_model_filename(choice: str, current_value: str):
1899
+ choice_key = (choice or "cnn_lstm").lower().replace("-", "_")
1900
+ suggested = MODEL_FILENAME_BY_TYPE.get(
1901
+ choice_key, Path(LOCAL_MODEL_FILE).name
1902
+ )
1903
+ known_defaults = set(MODEL_FILENAME_BY_TYPE.values())
1904
+ current_name = Path(current_value).name if current_value else ""
1905
+ if current_name and current_name not in known_defaults:
1906
+ return gr.update()
1907
+ return gr.update(value=suggested)
1908
+
1909
+ model_selector.change(
1910
+ _suggest_model_filename,
1911
+ inputs=[model_selector, model_name],
1912
+ outputs=model_name,
1913
+ )
1914
+
1915
+ with gr.Row():
1916
+ train_button = gr.Button("🛠️ Start Training", variant="primary")
1917
+ progress_button = gr.Button(
1918
+ "📊 Check Progress", variant="secondary"
1919
+ )
1920
+
1921
+ # Training status display
1922
+ training_status = gr.Textbox(label="Training Status", interactive=False)
1923
+ report_output = gr.Dataframe(
1924
+ label="Classification report", interactive=False
1925
+ )
1926
+ history_output = gr.JSON(label="Training history")
1927
+ confusion_output = gr.Dataframe(
1928
+ label="Confusion matrix", interactive=False
1929
+ )
1930
+
1931
+ # Message area at the bottom for progress updates
1932
+ with gr.Accordion("📋 Progress Messages", open=True):
1933
+ progress_messages = gr.Textbox(
1934
+ label="Training Messages",
1935
+ lines=8,
1936
+ max_lines=20,
1937
+ interactive=False,
1938
+ autoscroll=True,
1939
+ placeholder="Click 'Check Progress' to see training updates...",
1940
+ )
1941
+ with gr.Row():
1942
+ gr.Button("🗑️ Clear Messages", variant="secondary").click(
1943
+ lambda: "", outputs=[progress_messages]
1944
+ )
1945
+
1946
+ def _run_training(
1947
+ file_paths,
1948
+ label_column,
1949
+ model_choice,
1950
+ sequence_length,
1951
+ stride,
1952
+ validation_split,
1953
+ batch_size,
1954
+ epochs,
1955
+ output_dir,
1956
+ model_filename,
1957
+ scaler_filename,
1958
+ metadata_filename,
1959
+ enable_tensorboard,
1960
+ ):
1961
+ base_dir = normalise_output_directory(output_dir)
1962
+ try:
1963
+ base_dir.mkdir(parents=True, exist_ok=True)
1964
+
1965
+ model_path = resolve_output_path(
1966
+ base_dir,
1967
+ model_filename,
1968
+ Path(LOCAL_MODEL_FILE).name,
1969
+ )
1970
+ scaler_path = resolve_output_path(
1971
+ base_dir,
1972
+ scaler_filename,
1973
+ Path(LOCAL_SCALER_FILE).name,
1974
+ )
1975
+ metadata_path = resolve_output_path(
1976
+ base_dir,
1977
+ metadata_filename,
1978
+ Path(LOCAL_METADATA_FILE).name,
1979
+ )
1980
+
1981
+ model_path.parent.mkdir(parents=True, exist_ok=True)
1982
+ scaler_path.parent.mkdir(parents=True, exist_ok=True)
1983
+ metadata_path.parent.mkdir(parents=True, exist_ok=True)
1984
+
1985
+ # Create status file path for progress tracking
1986
+ status_file = model_path.parent / "training_status.txt"
1987
+
1988
+ # Initialize status
1989
+ with open(status_file, "w") as f:
1990
+ f.write("Starting training setup...")
1991
+
1992
+ if not file_paths:
1993
+ raise ValueError(
1994
+ "No training CSVs were found in the database cache. "
1995
+ "Use 'Reload dataset from database' and try again."
1996
+ )
1997
+
1998
+ with open(status_file, "w") as f:
1999
+ f.write("Loading and validating CSV files...")
2000
+
2001
+ available_paths = [
2002
+ path for path in file_paths if Path(path).exists()
2003
+ ]
2004
+ missing_paths = [
2005
+ Path(path).name
2006
+ for path in file_paths
2007
+ if not Path(path).exists()
2008
+ ]
2009
+ if not available_paths:
2010
+ raise ValueError(
2011
+ "Database training dataset is unavailable. Reload the dataset and retry."
2012
+ )
2013
+
2014
+ dfs = [load_measurement_csv(path) for path in available_paths]
2015
+ combined = pd.concat(dfs, ignore_index=True)
2016
+
2017
+ # Validate data size and provide recommendations
2018
+ total_samples = len(combined)
2019
+ if total_samples < 100:
+ print(
+ f"Warning: Only {total_samples} samples. Recommend at least 1000 for good results."
+ )
+ print(
+ "Automatically switching to SVM for small dataset compatibility."
+ )
+ # The radio value is still "CNN-LSTM"/"TCN" at this point, so normalise
+ # before comparing, otherwise the automatic switch never fires.
+ if (model_choice or "").lower().replace("-", "_") in {"cnn_lstm", "tcn"}:
+ model_choice = "svm"
+ print(
+ "Model type changed to SVM for better small dataset performance."
+ )
2031
+ if total_samples < 10:
2032
+ raise ValueError(
2033
+ f"Insufficient data: {total_samples} samples. Need at least 10 samples for training."
2034
+ )
2035
+
2036
+ label_column = (label_column or LABEL_COLUMN).strip()
2037
+ if not label_column:
2038
+ raise ValueError("Label column name cannot be empty.")
2039
+
2040
+ model_choice = (
2041
+ (model_choice or "CNN-LSTM").lower().replace("-", "_")
2042
+ )
2043
+ if model_choice not in {"cnn_lstm", "tcn", "svm"}:
2044
+ raise ValueError(
2045
+ "Select CNN-LSTM, TCN, or SVM for the model architecture."
2046
+ )
2047
+
2048
+ with open(status_file, "w") as f:
2049
+ f.write(
2050
+ f"Starting {model_choice.upper()} training with {len(combined)} samples..."
2051
+ )
2052
+
2053
+ # Start training
2054
+ result = train_from_dataframe(
2055
+ combined,
2056
+ label_column=label_column,
2057
+ feature_columns=None,
2058
+ sequence_length=int(sequence_length),
2059
+ stride=int(stride),
2060
+ validation_split=float(validation_split),
2061
+ batch_size=int(batch_size),
2062
+ epochs=int(epochs),
2063
+ model_type=model_choice,
2064
+ model_path=model_path,
2065
+ scaler_path=scaler_path,
2066
+ metadata_path=metadata_path,
2067
+ enable_tensorboard=bool(enable_tensorboard),
2068
+ )
2069
+
2070
+ refresh_artifacts(
2071
+ Path(result["model_path"]),
2072
+ Path(result["scaler_path"]),
2073
+ Path(result["metadata_path"]),
2074
+ )
2075
+
2076
+ report_df = classification_report_to_dataframe(
2077
+ result["classification_report"]
2078
+ )
2079
+ confusion_df = confusion_matrix_to_dataframe(
2080
+ result["confusion_matrix"], result["class_names"]
2081
+ )
2082
+ tensorboard_dir = result.get("tensorboard_log_dir")
2083
+ tensorboard_zip = result.get("tensorboard_zip_path")
2084
+
2085
+ architecture = result["model_type"].replace("_", "-").upper()
2086
+ status = (
2087
+ f"Training complete using a {architecture} architecture. "
2088
+ f"{result['num_sequences']} windows derived from "
2089
+ f"{result['num_samples']} rows across {len(available_paths)} file(s)."
2090
+ f" Artifacts saved to:"
2091
+ f"\n• Model: {result['model_path']}\n"
2092
+ f"• Scaler: {result['scaler_path']}\n"
2093
+ f"• Metadata: {result['metadata_path']}"
2094
+ )
2095
+
2096
+ status += f"\nLabel column used: {result.get('label_column', label_column)}"
2097
+
2098
+ if tensorboard_dir:
2099
+ status += (
2100
+ f"\nTensorBoard logs directory: {tensorboard_dir}"
2101
+ f'\nRun `tensorboard --logdir "{tensorboard_dir}"` to inspect the training curves.'
2102
+ "\nDownload the archive below to explore the run offline."
2103
+ )
2104
+
2105
+ if missing_paths:
2106
+ skipped = ", ".join(missing_paths)
2107
+ status = f"⚠️ Skipped missing files: {skipped}\n" + status
2108
+
2109
+ artifact_choices, selected_artifact = gather_artifact_choices(
2110
+ str(base_dir), result["model_path"]
2111
+ )
2112
+
2113
+ return (
2114
+ status,
2115
+ report_df,
2116
+ result["history"],
2117
+ confusion_df,
2118
+ download_button_state(result["model_path"]),
2119
+ download_button_state(result["scaler_path"]),
2120
+ download_button_state(result["metadata_path"]),
2121
+ download_button_state(tensorboard_zip),
2122
+ gr.update(value=result.get("label_column", label_column)),
2123
+ gr.update(
2124
+ choices=artifact_choices, value=selected_artifact
2125
+ ),
2126
+ download_button_state(selected_artifact),
2127
+ )
2128
+ except Exception as exc:
2129
+ artifact_choices, selected_artifact = gather_artifact_choices(
2130
+ str(base_dir)
2131
+ )
2132
+ return (
2133
+ f"Training failed: {exc}",
2134
+ pd.DataFrame(),
2135
+ {},
2136
+ pd.DataFrame(),
2137
+ download_button_state(None),
2138
+ download_button_state(None),
2139
+ download_button_state(None),
2140
+ download_button_state(None),
2141
+ gr.update(),
2142
+ gr.update(
2143
+ choices=artifact_choices, value=selected_artifact
2144
+ ),
2145
+ download_button_state(selected_artifact),
2146
+ )
2147
+
2148
+ def _check_progress(output_dir, model_filename, current_messages):
2149
+ """Check training progress by reading status file and accumulate messages."""
2150
+ model_path = resolve_output_path(
2151
+ output_dir, model_filename, Path(LOCAL_MODEL_FILE).name
2152
+ )
2153
+ status_file = model_path.parent / "training_status.txt"
2154
+ status_message = read_training_status(str(status_file))
2155
+
2156
+ # Add timestamp to the message
2157
+ from datetime import datetime
2158
+
2159
+ timestamp = datetime.now().strftime("%H:%M:%S")
2160
+ new_message = f"[{timestamp}] {status_message}"
2161
+
2162
+ # Accumulate messages, keeping last 50 lines to prevent overflow
2163
+ if current_messages:
2164
+ lines = current_messages.split("\n")
2165
+ lines.append(new_message)
2166
+ # Keep only last 50 lines
2167
+ if len(lines) > 50:
2168
+ lines = lines[-50:]
2169
+ accumulated_messages = "\n".join(lines)
2170
+ else:
2171
+ accumulated_messages = new_message
2172
+
2173
+ return accumulated_messages
2174
+
2175
+ train_button.click(
2176
+ _run_training,
2177
+ inputs=[
2178
+ training_files_state,
2179
+ label_input,
2180
+ model_selector,
2181
+ sequence_length_train,
2182
+ stride_train,
2183
+ validation_train,
2184
+ batch_train,
2185
+ epochs_train,
2186
+ output_directory,
2187
+ model_name,
2188
+ scaler_name,
2189
+ metadata_name,
2190
+ tensorboard_toggle,
2191
+ ],
2192
+ outputs=[
2193
+ training_status,
2194
+ report_output,
2195
+ history_output,
2196
+ confusion_output,
2197
+ model_download_button,
2198
+ scaler_download_button,
2199
+ metadata_download_button,
2200
+ tensorboard_download_button,
2201
+ label_input,
2202
+ artifact_browser,
2203
+ artifact_download_button,
2204
+ ],
2205
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2206
+ )
2207
+
2208
+ progress_button.click(
2209
+ _check_progress,
2210
+ inputs=[output_directory, model_name, progress_messages],
2211
+ outputs=[progress_messages],
2212
+ )
2213
+
2214
+ year_selector.change(
2215
+ on_year_change,
2216
+ inputs=[year_selector],
2217
+ outputs=[
2218
+ month_selector,
2219
+ day_selector,
2220
+ available_files,
2221
+ repo_status,
2222
+ ],
2223
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2224
+ )
2225
+
2226
+ month_selector.change(
2227
+ on_month_change,
2228
+ inputs=[year_selector, month_selector],
2229
+ outputs=[day_selector, available_files, repo_status],
2230
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2231
+ )
2232
+
2233
+ day_selector.change(
2234
+ on_day_change,
2235
+ inputs=[year_selector, month_selector, day_selector],
2236
+ outputs=[available_files, repo_status],
2237
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2238
+ )
2239
+
2240
+ download_button.click(
2241
+ download_selected_files,
2242
+ inputs=[
2243
+ year_selector,
2244
+ month_selector,
2245
+ day_selector,
2246
+ available_files,
2247
+ label_input,
2248
+ ],
2249
+ outputs=[
2250
+ training_files_state,
2251
+ training_files_summary,
2252
+ label_input,
2253
+ dataset_info,
2254
+ available_files,
2255
+ repo_status,
2256
+ ],
2257
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2258
+ )
2259
+
2260
+ year_download_button.click(
2261
+ download_year_bundle,
2262
+ inputs=[year_selector, label_input],
2263
+ outputs=[
2264
+ training_files_state,
2265
+ training_files_summary,
2266
+ label_input,
2267
+ dataset_info,
2268
+ available_files,
2269
+ repo_status,
2270
+ ],
2271
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2272
+ )
2273
+
2274
+ month_download_button.click(
2275
+ download_month_bundle,
2276
+ inputs=[year_selector, month_selector, label_input],
2277
+ outputs=[
2278
+ training_files_state,
2279
+ training_files_summary,
2280
+ label_input,
2281
+ dataset_info,
2282
+ available_files,
2283
+ repo_status,
2284
+ ],
2285
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2286
+ )
2287
+
2288
+ day_download_button.click(
2289
+ download_day_bundle,
2290
+ inputs=[year_selector, month_selector, day_selector, label_input],
2291
+ outputs=[
2292
+ training_files_state,
2293
+ training_files_summary,
2294
+ label_input,
2295
+ dataset_info,
2296
+ available_files,
2297
+ repo_status,
2298
+ ],
2299
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2300
+ )
2301
+
2302
+ def _reload_dataset(current_label):
2303
+ local = load_repository_training_files(
2304
+ current_label, force_refresh=True
2305
+ )
2306
+ remote = refresh_remote_browser(force_refresh=True)
2307
+ return (*local, *remote)
2308
+
2309
+ dataset_refresh.click(
2310
+ _reload_dataset,
2311
+ inputs=[label_input],
2312
+ outputs=[
2313
+ training_files_state,
2314
+ training_files_summary,
2315
+ label_input,
2316
+ dataset_info,
2317
+ year_selector,
2318
+ month_selector,
2319
+ day_selector,
2320
+ available_files,
2321
+ repo_status,
2322
+ ],
2323
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2324
+ )
2325
+
2326
+ clear_cache_button.click(
2327
+ clear_downloaded_cache,
2328
+ inputs=[label_input],
2329
+ outputs=[
2330
+ training_files_state,
2331
+ training_files_summary,
2332
+ label_input,
2333
+ dataset_info,
2334
+ year_selector,
2335
+ month_selector,
2336
+ day_selector,
2337
+ available_files,
2338
+ repo_status,
2339
+ ],
2340
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2341
+ )
2342
+
2343
+ def _initialise_dataset():
2344
+ local = load_repository_training_files(
2345
+ LABEL_COLUMN, force_refresh=False
2346
+ )
2347
+ remote = refresh_remote_browser(force_refresh=False)
2348
+ return (*local, *remote)
2349
+
2350
+ demo.load(
2351
+ _initialise_dataset,
2352
+ inputs=None,
2353
+ outputs=[
2354
+ training_files_state,
2355
+ training_files_summary,
2356
+ label_input,
2357
+ dataset_info,
2358
+ year_selector,
2359
+ month_selector,
2360
+ day_selector,
2361
+ available_files,
2362
+ repo_status,
2363
+ ],
2364
+ queue=False,
2365
+ )
2366
+
2367
+ return demo
2368
+
2369
+
2370
+ # --------------------------------------------------------------------------------------
2371
+ # Launch helpers
2372
+ # --------------------------------------------------------------------------------------
2373
+
2374
+
2375
+ def resolve_server_port() -> int:
2376
+ for env_var in ("PORT", "GRADIO_SERVER_PORT"):
2377
+ value = os.environ.get(env_var)
2378
+ if value:
2379
+ try:
2380
+ return int(value)
2381
+ except ValueError:
2382
+ print(f"Ignoring invalid port value from {env_var}: {value}")
2383
+ return 7860
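+
+ # e.g. PORT=8080 python app.py serves on port 8080; an unset or malformed
+ # value falls back to Gradio's default 7860.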
2384
+
2385
+
2386
+ def main():
2387
+ print("Building Gradio interface...")
2388
+ try:
2389
+ demo = build_interface()
2390
+ print("Interface built successfully")
2391
+ except Exception as e:
2392
+ print(f"Failed to build interface: {e}")
2393
+ import traceback
2394
+
2395
+ traceback.print_exc()
2396
+ return
2397
+
2398
+ print("Setting up queue...")
2399
+ try:
2400
+ demo.queue(max_size=QUEUE_MAX_SIZE)
2401
+ print("Queue configured")
2402
+ except Exception as e:
2403
+ print(f"Failed to configure queue: {e}")
2404
+
2405
+ try:
2406
+ port = resolve_server_port()
2407
+ print(f"Launching Gradio app on port {port}")
2408
+ demo.launch(server_name="0.0.0.0", server_port=port, show_error=True)
2409
+ except OSError as exc:
2410
+ print("Failed to launch on requested port:", exc)
2411
+ try:
2412
+ demo.launch(server_name="0.0.0.0", show_error=True)
2413
+ except Exception as e:
2414
+ print(f"Failed to launch completely: {e}")
2415
+ except Exception as e:
2416
+ print(f"Unexpected launch error: {e}")
2417
+ import traceback
2418
+
2419
+ traceback.print_exc()
2420
+
2421
+
2422
+ if __name__ == "__main__":
2423
+ print("=" * 50)
2424
+ print("PMU Fault Classification App Starting")
2425
+ print(f"Python version: {os.sys.version}")  # os.sys is just sys; importing sys directly would be cleaner
2426
+ print(f"Working directory: {os.getcwd()}")
2427
+ print(f"HUB_REPO: {HUB_REPO}")
2428
+ print(f"Model available: {MODEL is not None}")
2429
+ print(f"Scaler available: {SCALER is not None}")
2430
+ print("=" * 50)
2431
+ main()
.history/app_20251009232414.py ADDED
@@ -0,0 +1,2431 @@
+ """Gradio front-end for Fault_Classification_PMU_Data models.
+
+ The application loads a CNN-LSTM model (and accompanying scaler/metadata)
+ produced by ``fault_classification_pmu.py`` and exposes a streamlined
+ prediction interface optimised for Hugging Face Spaces deployment. It supports
+ raw PMU time-series CSV uploads as well as manual comma-separated feature
+ vectors.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import shutil
+
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1")
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
+ os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
+
+ import re
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+ import gradio as gr
+ import joblib
+ import numpy as np
+ import pandas as pd
+ import requests
+ from huggingface_hub import hf_hub_download
+ from tensorflow.keras.models import load_model
+
+ from fault_classification_pmu import (
+     DEFAULT_FEATURE_COLUMNS as TRAINING_DEFAULT_FEATURE_COLUMNS,
+     LABEL_GUESS_CANDIDATES as TRAINING_LABEL_GUESSES,
+     train_from_dataframe,
+ )
+
+ # --------------------------------------------------------------------------------------
+ # Configuration
+ # --------------------------------------------------------------------------------------
+ DEFAULT_FEATURE_COLUMNS: List[str] = list(TRAINING_DEFAULT_FEATURE_COLUMNS)
+ DEFAULT_SEQUENCE_LENGTH = 32
+ DEFAULT_STRIDE = 4
+
+ LOCAL_MODEL_FILE = os.environ.get("PMU_MODEL_FILE", "pmu_cnn_lstm_model.keras")
+ LOCAL_SCALER_FILE = os.environ.get("PMU_SCALER_FILE", "pmu_feature_scaler.pkl")
+ LOCAL_METADATA_FILE = os.environ.get("PMU_METADATA_FILE", "pmu_metadata.json")
+
+ MODEL_OUTPUT_DIR = Path(os.environ.get("PMU_MODEL_DIR", "model")).resolve()
+ MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+ HUB_REPO = os.environ.get("PMU_HUB_REPO", "")
+ HUB_MODEL_FILENAME = os.environ.get("PMU_HUB_MODEL_FILENAME", LOCAL_MODEL_FILE)
+ HUB_SCALER_FILENAME = os.environ.get("PMU_HUB_SCALER_FILENAME", LOCAL_SCALER_FILE)
+ HUB_METADATA_FILENAME = os.environ.get("PMU_HUB_METADATA_FILENAME", LOCAL_METADATA_FILE)
+
+ ENV_MODEL_PATH = "PMU_MODEL_PATH"
+ ENV_SCALER_PATH = "PMU_SCALER_PATH"
+ ENV_METADATA_PATH = "PMU_METADATA_PATH"
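+ # Example (hypothetical repo id): setting PMU_HUB_REPO="your-user/pmu-artifacts"
+ # makes the three HUB_*_FILENAME artifacts above downloadable from that Hub
+ # repository whenever no local copies are found.
+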
+ # --------------------------------------------------------------------------------------
+ # Utility functions for loading artifacts
+ # --------------------------------------------------------------------------------------
+
+
+ def download_from_hub(filename: str) -> Optional[Path]:
+     if not HUB_REPO or not filename:
+         return None
+     try:
+         print(f"Downloading {filename} from {HUB_REPO} ...")
+         # hf_hub_download is used with its default retry/timeout behaviour here.
+         path = hf_hub_download(repo_id=HUB_REPO, filename=filename)
+         print("Downloaded", path)
+         return Path(path)
+     except Exception as exc:  # pragma: no cover - logging convenience
+         print("Failed to download", filename, "from", HUB_REPO, ":", exc)
+         print("Continuing without pre-trained model...")
+         return None
+
+
+ def resolve_artifact(
+     local_name: str, env_var: str, hub_filename: str
+ ) -> Optional[Path]:
+     print(f"Resolving artifact: {local_name}, env: {env_var}, hub: {hub_filename}")
+     candidates = [Path(local_name)] if local_name else []
+     if local_name:
+         candidates.append(MODEL_OUTPUT_DIR / Path(local_name).name)
+     env_value = os.environ.get(env_var)
+     if env_value:
+         candidates.append(Path(env_value))
+
+     for candidate in candidates:
+         if candidate and candidate.exists():
+             print(f"Found local artifact: {candidate}")
+             return candidate
+
+     print("No local artifacts found, checking hub...")
+     # Only try to download if we have a hub repo configured
+     if HUB_REPO:
+         return download_from_hub(hub_filename)
+     else:
+         print("No HUB_REPO configured, skipping download")
+         return None
+
+
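+ # Resolution order, illustrated for the default model artifact: the working
+ # directory copy of pmu_cnn_lstm_model.keras is tried first, then model/, then
+ # the path in $PMU_MODEL_PATH, and finally a Hub download when PMU_HUB_REPO is set.
+
+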
+ def load_metadata(path: Optional[Path]) -> Dict:
+     if path and path.exists():
+         try:
+             return json.loads(path.read_text())
+         except Exception as exc:  # pragma: no cover - metadata parsing errors
+             print("Failed to read metadata", path, exc)
+     return {}
+
+
+ def try_load_scaler(path: Optional[Path]):
+     if not path:
+         return None
+     try:
+         scaler = joblib.load(path)
+         print("Loaded scaler from", path)
+         return scaler
+     except Exception as exc:
+         print("Failed to load scaler", path, exc)
+         return None
+
+
+ # Initialize paths with error handling
+ print("Starting application initialization...")
+ try:
+     MODEL_PATH = resolve_artifact(LOCAL_MODEL_FILE, ENV_MODEL_PATH, HUB_MODEL_FILENAME)
+     print(f"Model path resolved: {MODEL_PATH}")
+ except Exception as e:
+     print(f"Model path resolution failed: {e}")
+     MODEL_PATH = None
+
+ try:
+     SCALER_PATH = resolve_artifact(
+         LOCAL_SCALER_FILE, ENV_SCALER_PATH, HUB_SCALER_FILENAME
+     )
+     print(f"Scaler path resolved: {SCALER_PATH}")
+ except Exception as e:
+     print(f"Scaler path resolution failed: {e}")
+     SCALER_PATH = None
+
+ try:
+     METADATA_PATH = resolve_artifact(
+         LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME
+     )
+     print(f"Metadata path resolved: {METADATA_PATH}")
+ except Exception as e:
+     print(f"Metadata path resolution failed: {e}")
+     METADATA_PATH = None
+
+ try:
+     METADATA = load_metadata(METADATA_PATH)
+     print(f"Metadata loaded: {len(METADATA)} entries")
+ except Exception as e:
+     print(f"Metadata loading failed: {e}")
+     METADATA = {}
+
+ # Queuing configuration
+ QUEUE_MAX_SIZE = 32
+ # Apply a small per-event concurrency limit to avoid relying on the deprecated
+ # ``concurrency_count`` parameter when enabling Gradio's request queue.
+ EVENT_CONCURRENCY_LIMIT = 2
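+ # Sketch of how these are consumed: inference handlers pass
+ # concurrency_limit=EVENT_CONCURRENCY_LIMIT to .click(), and the queue itself
+ # would be enabled with demo.queue(max_size=QUEUE_MAX_SIZE) before launch
+ # (the queue call is an assumption; only the click limit appears below).
+
+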
+ def try_load_model(path: Optional[Path], model_type: str, model_format: str):
+     if not path:
+         return None
+     try:
+         if model_type == "svm" or model_format == "joblib":
+             model = joblib.load(path)
+         else:
+             model = load_model(path)
+         print("Loaded model from", path)
+         return model
+     except Exception as exc:  # pragma: no cover - runtime diagnostics
+         print("Failed to load model", path, exc)
+         return None
+
+
+ FEATURE_COLUMNS: List[str] = list(DEFAULT_FEATURE_COLUMNS)
+ LABEL_CLASSES: List[str] = []
+ LABEL_COLUMN: str = "Fault"
+ SEQUENCE_LENGTH: int = DEFAULT_SEQUENCE_LENGTH
+ DEFAULT_WINDOW_STRIDE: int = DEFAULT_STRIDE
+ MODEL_TYPE: str = "cnn_lstm"
+ MODEL_FORMAT: str = "keras"
+
+
+ def _model_output_path(filename: str) -> str:
+     return str(MODEL_OUTPUT_DIR / Path(filename).name)
+
+
+ MODEL_FILENAME_BY_TYPE: Dict[str, str] = {
+     "cnn_lstm": Path(LOCAL_MODEL_FILE).name,
+     "tcn": "pmu_tcn_model.keras",
+     "svm": "pmu_svm_model.joblib",
+ }
+
+ REQUIRED_PMU_COLUMNS: Tuple[str, ...] = tuple(DEFAULT_FEATURE_COLUMNS)
+ TRAINING_UPLOAD_DIR = Path(
+     os.environ.get("PMU_TRAINING_UPLOAD_DIR", "training_uploads")
+ )
+ TRAINING_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
+
+ TRAINING_DATA_REPO = os.environ.get(
+     "PMU_TRAINING_DATA_REPO", "VincentCroft/ThesisModelData"
+ )
+ TRAINING_DATA_BRANCH = os.environ.get("PMU_TRAINING_DATA_BRANCH", "main")
+ TRAINING_DATA_DIR = Path(os.environ.get("PMU_TRAINING_DATA_DIR", "training_dataset"))
+ TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
+
+ GITHUB_CONTENT_CACHE: Dict[str, List[Dict[str, Any]]] = {}
+
+
+ APP_CSS = """
+ #available-files-section {
+     position: relative;
+     display: flex;
+     flex-direction: column;
+     gap: 0.75rem;
+     border-radius: 0.75rem;
+     isolation: isolate;
+ }
+
+ #available-files-grid {
+     position: static;
+     overflow: visible;
+ }
+
+ #available-files-grid .form {
+     position: static;
+     min-height: 16rem;
+ }
+
+ #available-files-grid .wrap {
+     display: grid;
+     grid-template-columns: repeat(4, minmax(0, 1fr));
+     gap: 0.5rem;
+     max-height: 24rem;
+     min-height: 16rem;
+     overflow-y: auto;
+     padding-right: 0.25rem;
+ }
+
+ #available-files-grid .wrap > div {
+     min-width: 0;
+ }
+
+ #available-files-grid .wrap label {
+     margin: 0;
+     display: flex;
+     align-items: center;
+     padding: 0.45rem 0.65rem;
+     border-radius: 0.65rem;
+     background-color: rgba(255, 255, 255, 0.05);
+     border: 1px solid rgba(255, 255, 255, 0.08);
+     transition: background-color 0.2s ease, border-color 0.2s ease;
+     min-height: 2.5rem;
+ }
+
+ #available-files-grid .wrap label:hover {
+     background-color: rgba(90, 200, 250, 0.16);
+     border-color: rgba(90, 200, 250, 0.4);
+ }
+
+ #available-files-grid .wrap label span {
+     overflow: hidden;
+     text-overflow: ellipsis;
+     white-space: nowrap;
+ }
+
+ #available-files-section .gradio-loading,
+ #available-files-grid .gradio-loading {
+     position: absolute;
+     top: 0;
+     left: 0;
+     right: 0;
+     bottom: 0;
+     width: 100%;
+     height: 100%;
+     display: flex;
+     align-items: center;
+     justify-content: center;
+     background: rgba(10, 14, 23, 0.92);
+     border-radius: 0.75rem;
+     z-index: 999;
+     padding: 1.5rem;
+     pointer-events: auto;
+ }
+
+ #available-files-section .gradio-loading > *,
+ #available-files-grid .gradio-loading > * {
+     width: 100%;
+ }
+
+ #available-files-section .gradio-loading progress,
+ #available-files-section .gradio-loading .progress-bar,
+ #available-files-section .gradio-loading .loading-progress,
+ #available-files-section .gradio-loading [role="progressbar"],
+ #available-files-section .gradio-loading .wrap,
+ #available-files-section .gradio-loading .inner,
+ #available-files-grid .gradio-loading progress,
+ #available-files-grid .gradio-loading .progress-bar,
+ #available-files-grid .gradio-loading .loading-progress,
+ #available-files-grid .gradio-loading [role="progressbar"],
+ #available-files-grid .gradio-loading .wrap,
+ #available-files-grid .gradio-loading .inner {
+     width: 100% !important;
+     max-width: none !important;
+ }
+
+ #available-files-section .gradio-loading .status,
+ #available-files-section .gradio-loading .message,
+ #available-files-section .gradio-loading .label,
+ #available-files-grid .gradio-loading .status,
+ #available-files-grid .gradio-loading .message,
+ #available-files-grid .gradio-loading .label {
+     text-align: center;
+ }
+
+ #date-browser-row {
+     gap: 0.75rem;
+ }
+
+ #date-browser-row .date-browser-column {
+     flex: 1 1 0%;
+     min-width: 0;
+ }
+
+ #date-browser-row .date-browser-column > .gradio-dropdown,
+ #date-browser-row .date-browser-column > .gradio-button {
+     width: 100%;
+ }
+
+ #date-browser-row .date-browser-column > .gradio-dropdown > div {
+     width: 100%;
+ }
+
+ #date-browser-row .date-browser-column .gradio-button {
+     justify-content: center;
+ }
+
+ #training-files-summary textarea {
+     max-height: 12rem;
+     overflow-y: auto;
+ }
+
+ #download-selected-button {
+     width: 100%;
+     position: relative;
+     z-index: 0;
+ }
+
+ #download-selected-button .gradio-button {
+     width: 100%;
+     justify-content: center;
+ }
+
+ #artifact-download-row {
+     gap: 0.75rem;
+ }
+
+ #artifact-download-row .artifact-download-button {
+     flex: 1 1 0%;
+     min-width: 0;
+ }
+
+ #artifact-download-row .artifact-download-button .gradio-button {
+     width: 100%;
+     justify-content: center;
+ }
+ """
+
+
+ def _github_cache_key(path: str) -> str:
+     return path or "__root__"
+
+
+ def _github_api_url(path: str) -> str:
+     clean_path = path.strip("/")
+     base = f"https://api.github.com/repos/{TRAINING_DATA_REPO}/contents"
+     if clean_path:
+         return f"{base}/{clean_path}?ref={TRAINING_DATA_BRANCH}"
+     return f"{base}?ref={TRAINING_DATA_BRANCH}"
+
+
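+ # For example, path "2024/01" (an illustrative date folder) resolves to
+ # https://api.github.com/repos/VincentCroft/ThesisModelData/contents/2024/01?ref=main
+ # under the default repository and branch configuration.
+
+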
+ def list_remote_directory(
+     path: str = "", *, force_refresh: bool = False
+ ) -> List[Dict[str, Any]]:
+     key = _github_cache_key(path)
+     if not force_refresh and key in GITHUB_CONTENT_CACHE:
+         return GITHUB_CONTENT_CACHE[key]
+
+     url = _github_api_url(path)
+     response = requests.get(url, timeout=30)
+     if response.status_code != 200:
+         raise RuntimeError(
+             f"GitHub API request failed for `{path or '.'}` (status {response.status_code})."
+         )
+
+     payload = response.json()
+     if not isinstance(payload, list):
+         raise RuntimeError(
+             "Unexpected GitHub API payload. Expected a directory listing."
+         )
+
+     GITHUB_CONTENT_CACHE[key] = payload
+     return payload
+
+
+ def list_remote_years(force_refresh: bool = False) -> List[str]:
+     entries = list_remote_directory("", force_refresh=force_refresh)
+     years = [item["name"] for item in entries if item.get("type") == "dir"]
+     return sorted(years)
+
+
+ def list_remote_months(year: str, *, force_refresh: bool = False) -> List[str]:
+     if not year:
+         return []
+     entries = list_remote_directory(year, force_refresh=force_refresh)
+     months = [item["name"] for item in entries if item.get("type") == "dir"]
+     return sorted(months)
+
+
+ def list_remote_days(
+     year: str, month: str, *, force_refresh: bool = False
+ ) -> List[str]:
+     if not year or not month:
+         return []
+     entries = list_remote_directory(f"{year}/{month}", force_refresh=force_refresh)
+     days = [item["name"] for item in entries if item.get("type") == "dir"]
+     return sorted(days)
+
+
+ def list_remote_files(
+     year: str, month: str, day: str, *, force_refresh: bool = False
+ ) -> List[str]:
+     if not year or not month or not day:
+         return []
+     entries = list_remote_directory(
+         f"{year}/{month}/{day}", force_refresh=force_refresh
+     )
+     files = [item["name"] for item in entries if item.get("type") == "file"]
+     return sorted(files)
+
+
+ def download_repository_file(year: str, month: str, day: str, filename: str) -> Path:
+     if not filename:
+         raise ValueError("Filename cannot be empty when downloading repository data.")
+
+     relative_parts = [part for part in (year, month, day, filename) if part]
+     if len(relative_parts) < 4:
+         raise ValueError("Provide year, month, day, and filename to download a CSV.")
+
+     relative_path = "/".join(relative_parts)
+     raw_url = (
+         f"https://raw.githubusercontent.com/{TRAINING_DATA_REPO}/"
+         f"{TRAINING_DATA_BRANCH}/{relative_path}"
+     )
+
+     response = requests.get(raw_url, stream=True, timeout=120)
+     if response.status_code != 200:
+         raise RuntimeError(
+             f"Failed to download `{relative_path}` (status {response.status_code})."
+         )
+
+     target_dir = TRAINING_DATA_DIR.joinpath(year, month, day)
+     target_dir.mkdir(parents=True, exist_ok=True)
+     target_path = target_dir / filename
+
+     with open(target_path, "wb") as handle:
+         for chunk in response.iter_content(chunk_size=1 << 20):
+             if chunk:
+                 handle.write(chunk)
+
+     return target_path
+
+
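+ # Files stream down in 1 MiB chunks; e.g. (hypothetical file)
+ # download_repository_file("2024", "01", "15", "pmu_export.csv") caches to
+ # training_dataset/2024/01/15/pmu_export.csv.
+
+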
+ def _normalise_header(name: str) -> str:
+     return str(name).strip().lower()
+
+
+ def guess_label_from_columns(
+     columns: Sequence[str], preferred: Optional[str] = None
+ ) -> Optional[str]:
+     if not columns:
+         return preferred
+
+     lookup = {_normalise_header(col): str(col) for col in columns}
+
+     if preferred:
+         preferred_stripped = preferred.strip()
+         for col in columns:
+             if str(col).strip() == preferred_stripped:
+                 return str(col)
+         preferred_norm = _normalise_header(preferred)
+         if preferred_norm in lookup:
+             return lookup[preferred_norm]
+
+     for guess in TRAINING_LABEL_GUESSES:
+         guess_norm = _normalise_header(guess)
+         if guess_norm in lookup:
+             return lookup[guess_norm]
+
+     for col in columns:
+         if _normalise_header(col).startswith("fault"):
+             return str(col)
+
+     return str(columns[0])
+
+
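+ # Matching is case- and whitespace-insensitive: for columns
+ # ["Timestamp", "FAULT "] and preferred="fault", the normalised lookup returns
+ # "FAULT "; with no match at all, the first column is used as a fallback.
+
+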
+ def summarise_training_files(paths: Sequence[str], notes: Sequence[str]) -> str:
+     lines = [Path(path).name for path in paths]
+     lines.extend(notes)
+     return "\n".join(lines) if lines else "No training files available."
+
+
+ def read_training_status(status_file_path: str) -> str:
+     """Read the current training status from file."""
+     try:
+         if Path(status_file_path).exists():
+             with open(status_file_path, "r") as f:
+                 return f.read().strip()
+     except Exception:
+         pass
+     return "Training status unavailable"
+
+
+ def _persist_uploaded_file(file_obj) -> Optional[Path]:
+     if file_obj is None:
+         return None
+
+     if isinstance(file_obj, (str, Path)):
+         source = Path(file_obj)
+         original_name = source.name
+     else:
+         source = Path(getattr(file_obj, "name", "") or getattr(file_obj, "path", ""))
+         original_name = getattr(file_obj, "orig_name", source.name) or source.name
+     if not source or not source.exists():
+         return None
+
+     original_name = Path(original_name).name or source.name
+
+     base_path = Path(original_name)
+     destination = TRAINING_UPLOAD_DIR / base_path.name
+     counter = 1
+     while destination.exists():
+         suffix = base_path.suffix or ".csv"
+         destination = TRAINING_UPLOAD_DIR / f"{base_path.stem}_{counter}{suffix}"
+         counter += 1
+
+     shutil.copy2(source, destination)
+     return destination
+
+
+ def prepare_training_paths(
+     paths: Sequence[str], current_label: str, cleanup_missing: bool = False
+ ):
+     valid_paths: List[str] = []
+     notes: List[str] = []
+     columns_map: Dict[str, str] = {}
+     for path in paths:
+         try:
+             df = load_measurement_csv(path)
+         except Exception as exc:  # pragma: no cover - user file diagnostics
+             notes.append(f"⚠️ Skipped {Path(path).name}: {exc}")
+             if cleanup_missing:
+                 try:
+                     Path(path).unlink(missing_ok=True)
+                 except Exception:
+                     pass
+             continue
+         valid_paths.append(str(path))
+         for col in df.columns:
+             columns_map[_normalise_header(col)] = str(col)
+
+     summary = summarise_training_files(valid_paths, notes)
+     preferred = current_label or LABEL_COLUMN
+     dropdown_choices = (
+         sorted(columns_map.values()) if columns_map else [preferred or LABEL_COLUMN]
+     )
+     guessed = guess_label_from_columns(dropdown_choices, preferred)
+     dropdown_value = guessed or preferred or LABEL_COLUMN
+
+     return (
+         valid_paths,
+         summary,
+         gr.update(choices=dropdown_choices, value=dropdown_value),
+     )
+
+
+ def append_training_files(new_files, existing_paths: Sequence[str], current_label: str):
+     if isinstance(existing_paths, (str, Path)):
+         paths: List[str] = [str(existing_paths)]
+     elif existing_paths is None:
+         paths = []
+     else:
+         paths = list(existing_paths)
+     if new_files:
+         for file in new_files:
+             persisted = _persist_uploaded_file(file)
+             if persisted is None:
+                 continue
+             path_str = str(persisted)
+             if path_str not in paths:
+                 paths.append(path_str)
+
+     return prepare_training_paths(paths, current_label, cleanup_missing=True)
+
+
+ def load_repository_training_files(current_label: str, force_refresh: bool = False):
+     if force_refresh:
+         # Downloads are on-demand, so a refresh only re-scans the local cache;
+         # previously downloaded files are kept and nothing needs deleting.
+         pass
+
+     csv_paths = sorted(
+         str(path) for path in TRAINING_DATA_DIR.rglob("*.csv") if path.is_file()
+     )
+     if not csv_paths:
+         message = (
+             "No local database CSVs are available yet. Use the database browser "
+             "below to download specific days before training."
+         )
+         default_label = current_label or LABEL_COLUMN or "Fault"
+         return (
+             [],
+             message,
+             gr.update(choices=[default_label], value=default_label),
+             message,
+         )
+
+     valid_paths, summary, label_update = prepare_training_paths(
+         csv_paths, current_label, cleanup_missing=False
+     )
+
+     info = (
+         f"Ready with {len(valid_paths)} CSV file(s) cached locally under "
+         f"the database cache `{TRAINING_DATA_DIR}`."
+     )
+
+     return valid_paths, summary, label_update, info
+
+
+ def refresh_remote_browser(force_refresh: bool = False):
+     if force_refresh:
+         GITHUB_CONTENT_CACHE.clear()
+     try:
+         years = list_remote_years(force_refresh=force_refresh)
+         if years:
+             message = "Select a year, month, and day to list available CSV files."
+         else:
+             message = (
+                 "⚠️ No directories were found in the database root. Verify the upstream "
+                 "structure."
+             )
+         return (
+             gr.update(choices=years, value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             message,
+         )
+     except Exception as exc:
+         return (
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             f"⚠️ Failed to query database: {exc}",
+         )
+
+
+ def on_year_change(year: Optional[str]):
+     if not year:
+         return (
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             "Select a year to continue.",
+         )
+     try:
+         months = list_remote_months(year)
+         message = (
+             f"Year `{year}` selected. Choose a month to drill down."
+             if months
+             else f"⚠️ No months available under `{year}`."
+         )
+         return (
+             gr.update(choices=months, value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             message,
+         )
+     except Exception as exc:
+         return (
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             f"⚠️ Failed to list months: {exc}",
+         )
+
+
+ def on_month_change(year: Optional[str], month: Optional[str]):
+     if not year or not month:
+         return (
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             "Select a month to continue.",
+         )
+     try:
+         days = list_remote_days(year, month)
+         message = (
+             f"Month `{year}/{month}` ready. Pick a day to view files."
+             if days
+             else f"⚠️ No day folders found under `{year}/{month}`."
+         )
+         return (
+             gr.update(choices=days, value=None),
+             gr.update(choices=[], value=[]),
+             message,
+         )
+     except Exception as exc:
+         return (
+             gr.update(choices=[], value=None),
+             gr.update(choices=[], value=[]),
+             f"⚠️ Failed to list days: {exc}",
+         )
+
+
+ def on_day_change(year: Optional[str], month: Optional[str], day: Optional[str]):
+     if not year or not month or not day:
+         return (
+             gr.update(choices=[], value=[]),
+             "Select a day to load file names.",
+         )
+     try:
+         files = list_remote_files(year, month, day)
+         message = (
+             f"{len(files)} file(s) available for `{year}/{month}/{day}`."
+             if files
+             else f"⚠️ No CSV files found under `{year}/{month}/{day}`."
+         )
+         return (
+             gr.update(choices=files, value=[]),
+             message,
+         )
+     except Exception as exc:
+         return (
+             gr.update(choices=[], value=[]),
+             f"⚠️ Failed to list files: {exc}",
+         )
+
+
+ def download_selected_files(
+     year: Optional[str],
+     month: Optional[str],
+     day: Optional[str],
+     filenames: Sequence[str],
+     current_label: str,
+ ):
+     if not filenames:
+         message = "Select at least one CSV before downloading."
+         local = load_repository_training_files(current_label)
+         return (*local, gr.update(), message)
+
+     success: List[str] = []
+     notes: List[str] = []
+     for filename in filenames:
+         try:
+             path = download_repository_file(
+                 year or "", month or "", day or "", filename
+             )
+             success.append(str(path))
+         except Exception as exc:
+             notes.append(f"⚠️ {filename}: {exc}")
+
+     local = load_repository_training_files(current_label)
+
+     message_lines = []
+     if success:
+         message_lines.append(
+             f"Downloaded {len(success)} file(s) to the database cache `{TRAINING_DATA_DIR}`."
+         )
+     if notes:
+         message_lines.extend(notes)
+     if not message_lines:
+         message_lines.append("No files were downloaded.")
+
+     return (*local, gr.update(value=[]), "\n".join(message_lines))
+
+
+ def download_day_bundle(
+     year: Optional[str],
+     month: Optional[str],
+     day: Optional[str],
+     current_label: str,
+ ):
+     if not (year and month and day):
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             "Select a year, month, and day before downloading an entire day.",
+         )
+
+     try:
+         files = list_remote_files(year, month, day)
+     except Exception as exc:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"⚠️ Failed to list CSVs for `{year}/{month}/{day}`: {exc}",
+         )
+
+     if not files:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"No CSV files were found for `{year}/{month}/{day}`.",
+         )
+
+     result = list(download_selected_files(year, month, day, files, current_label))
+     result[-1] = (
+         f"Downloaded all {len(files)} CSV file(s) for `{year}/{month}/{day}`.\n"
+         f"{result[-1]}"
+     )
+     return tuple(result)
+
+
+ def download_month_bundle(
+     year: Optional[str], month: Optional[str], current_label: str
+ ):
+     if not (year and month):
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             "Select a year and month before downloading an entire month.",
+         )
+
+     try:
+         days = list_remote_days(year, month)
+     except Exception as exc:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"⚠️ Failed to enumerate days for `{year}/{month}`: {exc}",
+         )
+
+     if not days:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"No day folders were found for `{year}/{month}`.",
+         )
+
+     downloaded = 0
+     notes: List[str] = []
+     for day in days:
+         try:
+             files = list_remote_files(year, month, day)
+         except Exception as exc:
+             notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
+             continue
+         if not files:
+             notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
+             continue
+         for filename in files:
+             try:
+                 download_repository_file(year, month, day, filename)
+                 downloaded += 1
+             except Exception as exc:
+                 notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
+
+     local = load_repository_training_files(current_label)
+     message_lines = []
+     if downloaded:
+         message_lines.append(
+             f"Downloaded {downloaded} CSV file(s) for `{year}/{month}` into the "
+             f"database cache `{TRAINING_DATA_DIR}`."
+         )
+     message_lines.extend(notes)
+     if not message_lines:
+         message_lines.append("No files were downloaded.")
+
+     return (*local, gr.update(value=[]), "\n".join(message_lines))
+
+
+ def download_year_bundle(year: Optional[str], current_label: str):
+     if not year:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             "Select a year before downloading an entire year of CSVs.",
+         )
+
+     try:
+         months = list_remote_months(year)
+     except Exception as exc:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"⚠️ Failed to enumerate months for `{year}`: {exc}",
+         )
+
+     if not months:
+         local = load_repository_training_files(current_label)
+         return (
+             *local,
+             gr.update(),
+             f"No month folders were found for `{year}`.",
+         )
+
+     downloaded = 0
+     notes: List[str] = []
+     for month in months:
+         try:
+             days = list_remote_days(year, month)
+         except Exception as exc:
+             notes.append(f"⚠️ Failed to list `{year}/{month}`: {exc}")
+             continue
+         if not days:
+             notes.append(f"⚠️ No day folders in `{year}/{month}`.")
+             continue
+         for day in days:
+             try:
+                 files = list_remote_files(year, month, day)
+             except Exception as exc:
+                 notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
+                 continue
+             if not files:
+                 notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
+                 continue
+             for filename in files:
+                 try:
+                     download_repository_file(year, month, day, filename)
+                     downloaded += 1
+                 except Exception as exc:
+                     notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
+
+     local = load_repository_training_files(current_label)
+     message_lines = []
+     if downloaded:
+         message_lines.append(
+             f"Downloaded {downloaded} CSV file(s) for `{year}` into the "
+             f"database cache `{TRAINING_DATA_DIR}`."
+         )
+     message_lines.extend(notes)
+     if not message_lines:
+         message_lines.append("No files were downloaded.")
+
+     return (*local, gr.update(value=[]), "\n".join(message_lines))
+
+
+ def clear_downloaded_cache(current_label: str):
+     status_message = ""
+     try:
+         if TRAINING_DATA_DIR.exists():
+             shutil.rmtree(TRAINING_DATA_DIR)
+         TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
+         status_message = (
+             f"Cleared all downloaded CSVs from database cache `{TRAINING_DATA_DIR}`."
+         )
+     except Exception as exc:
+         status_message = f"⚠️ Failed to clear database cache: {exc}"
+
+     local = load_repository_training_files(current_label, force_refresh=True)
+     remote = list(refresh_remote_browser(force_refresh=False))
+     if status_message:
+         previous = remote[-1]
+         if isinstance(previous, str) and previous:
+             remote[-1] = f"{status_message}\n{previous}"
+         else:
+             remote[-1] = status_message
+
+     return (*local, *remote)
+
+
+ def normalise_output_directory(directory: Optional[str]) -> Path:
+     base = Path(directory or MODEL_OUTPUT_DIR)
+     base = base.expanduser()
+     if not base.is_absolute():
+         base = (Path.cwd() / base).resolve()
+     return base
+
+
+ def resolve_output_path(
+     directory: Optional[Union[Path, str]], filename: Optional[str], fallback: str
+ ) -> Path:
+     if isinstance(directory, Path):
+         base = directory
+     else:
+         base = normalise_output_directory(directory)
+     # Path("") normalises to ".", so test the raw filename to keep the
+     # fallback reachable when no filename is supplied.
+     if filename:
+         candidate = Path(filename).expanduser()
+         if candidate.is_absolute():
+             return candidate
+         return (base / candidate).resolve()
+     return (base / fallback).resolve()
+
+
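+ # For example, resolve_output_path("model", "my_model.keras", "pmu_cnn_lstm_model.keras")
+ # yields <cwd>/model/my_model.keras, while an empty filename falls back to the
+ # default artifact name inside the resolved directory.
+
+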
+ ARTIFACT_FILE_EXTENSIONS: Tuple[str, ...] = (
+     ".keras",
+     ".h5",
+     ".joblib",
+     ".pkl",
+     ".json",
+     ".onnx",
+     ".zip",
+     ".txt",
+ )
+
+
+ def gather_directory_choices(current: Optional[str]) -> Tuple[List[str], str]:
+     base = normalise_output_directory(current or str(MODEL_OUTPUT_DIR))
+     candidates = {str(base)}
+     try:
+         for candidate in base.parent.iterdir():
+             if candidate.is_dir():
+                 candidates.add(str(candidate.resolve()))
+     except Exception:
+         pass
+     return sorted(candidates), str(base)
+
+
+ def gather_artifact_choices(
+     directory: Optional[str], selection: Optional[str] = None
+ ) -> Tuple[List[Tuple[str, str]], Optional[str]]:
+     base = normalise_output_directory(directory)
+     choices: List[Tuple[str, str]] = []
+     selected_value: Optional[str] = None
+     if base.exists():
+         try:
+             artifacts = sorted(
+                 [
+                     path
+                     for path in base.iterdir()
+                     if path.is_file()
+                     and (
+                         not ARTIFACT_FILE_EXTENSIONS
+                         or path.suffix.lower() in ARTIFACT_FILE_EXTENSIONS
+                     )
+                 ],
+                 key=lambda path: path.name.lower(),
+             )
+             choices = [(artifact.name, str(artifact)) for artifact in artifacts]
+         except Exception:
+             choices = []
+
+     if selection and any(value == selection for _, value in choices):
+         selected_value = selection
+     elif choices:
+         selected_value = choices[0][1]
+
+     return choices, selected_value
+
+
+ def download_button_state(path: Optional[Union[str, Path]]):
+     if not path:
+         return gr.update(value=None, visible=False)
+     candidate = Path(path)
+     if candidate.exists():
+         return gr.update(value=str(candidate), visible=True)
+     return gr.update(value=None, visible=False)
+
+
+ def clear_training_files():
+     default_label = LABEL_COLUMN or "Fault"
+     for cached_file in TRAINING_UPLOAD_DIR.glob("*"):
+         try:
+             if cached_file.is_file():
+                 cached_file.unlink(missing_ok=True)
+         except Exception:
+             pass
+     return (
+         [],
+         "No training files selected.",
+         gr.update(choices=[default_label], value=default_label),
+         gr.update(value=None),
+     )
+
+
+ PROJECT_OVERVIEW_MD = """
+ ## Project Overview
+
+ This project focuses on classifying faults in electrical transmission lines and
+ grid-connected photovoltaic (PV) systems by combining ensemble learning
+ techniques with deep neural architectures.
+
+ ## Datasets
+
+ ### Transmission Line Fault Dataset
+ - 134,406 samples collected from Phasor Measurement Units (PMUs)
+ - 14 monitored channels covering currents, voltages, magnitudes, frequency, and phase angles
+ - Labels span symmetrical and asymmetrical faults: NF, L-G, LL, LL-G, LLL, and LLL-G
+ - Time span: 0 to 5.7 seconds with high-frequency sampling
+
+ ### Grid-Connected PV System Fault Dataset
+ - 2,163,480 samples from 16 experimental scenarios
+ - 14 features including PV array measurements (Ipv, Vpv, Vdc), three-phase currents/voltages, aggregate magnitudes (Iabc, Vabc), and frequency indicators (If, Vf)
+ - Captures array, inverter, grid anomaly, feedback sensor, and MPPT controller faults at 9.9989 μs sampling intervals
+
+ ## Data Format Quick Reference
+
+ Each measurement file may be comma or tab separated and typically exposes the
+ following ordered columns:
+
+ 1. `Timestamp`
+ 2. `[325] UPMU_SUB22:FREQ` – system frequency (Hz)
+ 3. `[326] UPMU_SUB22:DFDT` – frequency rate-of-change
+ 4. `[327] UPMU_SUB22:FLAG` – PMU status flag
+ 5. `[328] UPMU_SUB22-L1:MAG` – phase A voltage magnitude
+ 6. `[329] UPMU_SUB22-L1:ANG` – phase A voltage angle
+ 7. `[330] UPMU_SUB22-L2:MAG` – phase B voltage magnitude
+ 8. `[331] UPMU_SUB22-L2:ANG` – phase B voltage angle
+ 9. `[332] UPMU_SUB22-L3:MAG` – phase C voltage magnitude
+ 10. `[333] UPMU_SUB22-L3:ANG` – phase C voltage angle
+ 11. `[334] UPMU_SUB22-C1:MAG` – phase A current magnitude
+ 12. `[335] UPMU_SUB22-C1:ANG` – phase A current angle
+ 13. `[336] UPMU_SUB22-C2:MAG` – phase B current magnitude
+ 14. `[337] UPMU_SUB22-C2:ANG` – phase B current angle
+ 15. `[338] UPMU_SUB22-C3:MAG` – phase C current magnitude
+ 16. `[339] UPMU_SUB22-C3:ANG` – phase C current angle
+
+ The training tab automatically downloads the latest CSV exports from the
+ `VincentCroft/ThesisModelData` repository and concatenates them before building
+ sliding windows.
+
+ ## Models Developed
+
+ 1. **Support Vector Machine (SVM)** – provides the classical machine learning baseline with balanced accuracy across both datasets (85% PMU / 83% PV).
+ 2. **CNN-LSTM** – couples convolutional feature extraction with temporal memory, achieving 92% PMU / 89% PV accuracy.
+ 3. **Temporal Convolutional Network (TCN)** – leverages dilated convolutions for long-range context and delivers the best trade-off between accuracy and training time (94% PMU / 91% PV).
+
+ ## Results Summary
+
+ - **Transmission Line Fault Classification**: SVM 85%, CNN-LSTM 92%, TCN 94%
+ - **PV System Fault Classification**: SVM 83%, CNN-LSTM 89%, TCN 91%
+
+ Use the **Inference** tab to score new PMU/PV windows and the **Training** tab to
+ fine-tune or retrain any of the supported models directly within Hugging Face
+ Spaces. The logs panel will surface TensorBoard archives whenever deep-learning
+ models are trained.
+ """
+
+
+ def load_measurement_csv(path: str) -> pd.DataFrame:
+     """Read a PMU/PV measurement file with flexible separators and column mapping."""
+
+     try:
+         df = pd.read_csv(path, sep=None, engine="python", encoding="utf-8-sig")
+     except Exception:
+         df = None
+         for separator in ("\t", ",", ";"):
+             try:
+                 df = pd.read_csv(
+                     path, sep=separator, engine="python", encoding="utf-8-sig"
+                 )
+                 break
+             except Exception:
+                 df = None
+         if df is None:
+             raise
+
+     # Clean column names
+     df.columns = [str(col).strip() for col in df.columns]
+
+     print(f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns")
+     print(f"Columns: {list(df.columns)}")
+     print(f"Data shape: {df.shape}")
+
+     # Check if we have enough data for training
+     if len(df) < 100:
+         print(
+             f"Warning: Only {len(df)} rows of data. Recommend at least 1000 rows for effective training."
+         )
+
+     # Check for label column
+     has_label = any(
+         col.lower() in ["fault", "label", "class", "target"] for col in df.columns
+     )
+     if not has_label:
+         print(
+             "Warning: No label column found. Adding dummy 'Fault' column with value 'Normal' for all samples."
+         )
+         df["Fault"] = "Normal"  # Add dummy label for training
+
+     # Create column mapping - map similar column names to expected format
+     column_mapping = {}
+     expected_cols = list(REQUIRED_PMU_COLUMNS)
+
+     # If we have at least the right number of numeric columns after Timestamp, use positional mapping
+     if "Timestamp" in df.columns:
+         numeric_cols = [col for col in df.columns if col != "Timestamp"]
+         if len(numeric_cols) >= len(expected_cols):
+             # Map by position (after Timestamp)
+             for i, expected_col in enumerate(expected_cols):
+                 if i < len(numeric_cols):
+                     column_mapping[numeric_cols[i]] = expected_col
+
+     # Rename columns to match expected format
+     df = df.rename(columns=column_mapping)
+
+     # Check if we have the required columns after mapping
+     missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]
+     if missing:
+         # If still missing, try a more flexible approach
+         available_numeric = df.select_dtypes(include=[np.number]).columns.tolist()
+         if len(available_numeric) >= len(expected_cols):
+             # Use the first N numeric columns
+             for i, expected_col in enumerate(expected_cols):
+                 if i < len(available_numeric):
+                     if available_numeric[i] not in df.columns:
+                         continue
+                     df = df.rename(columns={available_numeric[i]: expected_col})
+
+         # Recheck missing columns
+         missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns]
+
+     if missing:
+         missing_str = ", ".join(missing)
+         available_str = ", ".join(df.columns.tolist())
+         raise ValueError(
+             f"Missing required PMU feature columns: {missing_str}. "
+             f"Available columns: {available_str}. "
+             "Please ensure your CSV has the correct format with Timestamp followed by PMU measurements."
+         )
+
+     return df
+
+
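+ # Parsing recap: an auto-detected separator is tried first, then tab, comma and
+ # semicolon fallbacks; non-Timestamp columns are mapped positionally onto
+ # REQUIRED_PMU_COLUMNS, so headers only need the right order, not exact names.
+
+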
+ def apply_metadata(metadata: Dict[str, Any]) -> None:
+     global FEATURE_COLUMNS, LABEL_CLASSES, LABEL_COLUMN, SEQUENCE_LENGTH, DEFAULT_WINDOW_STRIDE, MODEL_TYPE, MODEL_FORMAT
+     FEATURE_COLUMNS = [
+         str(col) for col in metadata.get("feature_columns", DEFAULT_FEATURE_COLUMNS)
+     ]
+     LABEL_CLASSES = [str(label) for label in metadata.get("label_classes", [])]
+     LABEL_COLUMN = str(metadata.get("label_column", "Fault"))
+     SEQUENCE_LENGTH = int(metadata.get("sequence_length", DEFAULT_SEQUENCE_LENGTH))
+     DEFAULT_WINDOW_STRIDE = int(metadata.get("stride", DEFAULT_STRIDE))
+     MODEL_TYPE = str(metadata.get("model_type", "cnn_lstm")).lower()
+     MODEL_FORMAT = str(
+         metadata.get("model_format", "joblib" if MODEL_TYPE == "svm" else "keras")
+     ).lower()
+
+
+ apply_metadata(METADATA)
+
+
+ def sync_label_classes_from_model(model: Optional[object]) -> None:
+     global LABEL_CLASSES
+     if model is None:
+         return
+     if hasattr(model, "classes_"):
+         LABEL_CLASSES = [str(label) for label in getattr(model, "classes_")]
+     elif not LABEL_CLASSES and hasattr(model, "output_shape"):
+         LABEL_CLASSES = [str(i) for i in range(int(model.output_shape[-1]))]
+
+
+ # Load model and scaler with error handling
+ print("Loading model and scaler...")
+ try:
+     MODEL = try_load_model(MODEL_PATH, MODEL_TYPE, MODEL_FORMAT)
+     print(f"Model loaded: {MODEL is not None}")
+ except Exception as e:
+     print(f"Model loading failed: {e}")
+     MODEL = None
+
+ try:
+     SCALER = try_load_scaler(SCALER_PATH)
+     print(f"Scaler loaded: {SCALER is not None}")
+ except Exception as e:
+     print(f"Scaler loading failed: {e}")
+     SCALER = None
+
+ try:
+     sync_label_classes_from_model(MODEL)
+     print("Label classes synchronized")
+ except Exception as e:
+     print(f"Label sync failed: {e}")
+
+ print("Application initialization completed.")
+ print(
+     f"Ready to start Gradio interface. Model available: {MODEL is not None}, Scaler available: {SCALER is not None}"
+ )
+
+
+ def refresh_artifacts(model_path: Path, scaler_path: Path, metadata_path: Path) -> None:
+     global MODEL_PATH, SCALER_PATH, METADATA_PATH, MODEL, SCALER, METADATA
+     MODEL_PATH = model_path
+     SCALER_PATH = scaler_path
+     METADATA_PATH = metadata_path
+     METADATA = load_metadata(metadata_path)
+     apply_metadata(METADATA)
+     MODEL = try_load_model(model_path, MODEL_TYPE, MODEL_FORMAT)
+     SCALER = try_load_scaler(scaler_path)
+     sync_label_classes_from_model(MODEL)
+
+
+ # --------------------------------------------------------------------------------------
+ # Pre-processing helpers
+ # --------------------------------------------------------------------------------------
+
+
+ def ensure_ready():
+     if MODEL is None or SCALER is None:
+         raise RuntimeError(
+             "The model and feature scaler are not available. Upload the trained model "
+             "(for example `pmu_cnn_lstm_model.keras`, `pmu_tcn_model.keras`, or `pmu_svm_model.joblib`), "
+             "the feature scaler (`pmu_feature_scaler.pkl`), and the metadata JSON (`pmu_metadata.json`) to the Space root "
+             "or configure the Hugging Face Hub environment variables so the artifacts can be downloaded "
+             "automatically."
+         )
+
+
+ def parse_text_features(text: str) -> np.ndarray:
+     cleaned = re.sub(r"[;\n\t]+", ",", text.strip())
+     arr = np.fromstring(cleaned, sep=",")
+     if arr.size == 0:
+         raise ValueError(
+             "No feature values were parsed. Please enter comma-separated numbers."
+         )
+     return arr.astype(np.float32)
+
+
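+ # Example: "1; 2\n3" is normalised to "1, 2,3" and parsed to
+ # array([1., 2., 3.], dtype=float32).
+
+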
+ def apply_scaler(sequences: np.ndarray) -> np.ndarray:
+     if SCALER is None:
+         return sequences
+     shape = sequences.shape
+     flattened = sequences.reshape(-1, shape[-1])
+     scaled = SCALER.transform(flattened)
+     return scaled.reshape(shape)
+
+
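+ # Shape flow: (n_windows, seq_len, n_features) is flattened to
+ # (n_windows * seq_len, n_features) for the fitted scaler and reshaped back,
+ # so scaling is applied per feature column.
+
+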
+ def make_sliding_windows(
+     data: np.ndarray, sequence_length: int, stride: int
+ ) -> np.ndarray:
+     if data.shape[0] < sequence_length:
+         raise ValueError(
+             f"The dataset contains {data.shape[0]} rows which is less than the requested sequence "
+             f"length {sequence_length}. Provide more samples or reduce the sequence length."
+         )
+     windows = [
+         data[start : start + sequence_length]
+         for start in range(0, data.shape[0] - sequence_length + 1, stride)
+     ]
+     return np.stack(windows)
+
+
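+ # With 100 rows, sequence_length=32 and stride=4, window starts are
+ # 0, 4, ..., 68, giving an output of shape (18, 32, n_features).
+
+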
+ def dataframe_to_sequences(
+     df: pd.DataFrame,
+     *,
+     sequence_length: int,
+     stride: int,
+     feature_columns: Sequence[str],
+     drop_label: bool = True,
+ ) -> np.ndarray:
+     work_df = df.copy()
+     if drop_label and LABEL_COLUMN in work_df.columns:
+         work_df = work_df.drop(columns=[LABEL_COLUMN])
+     if "Timestamp" in work_df.columns:
+         work_df = work_df.sort_values("Timestamp")
+
+     available_cols = [c for c in feature_columns if c in work_df.columns]
+     n_features = len(feature_columns)
+     if available_cols and len(available_cols) == n_features:
+         array = work_df[available_cols].astype(np.float32).to_numpy()
+         return make_sliding_windows(array, sequence_length, stride)
+
+     numeric_df = work_df.select_dtypes(include=[np.number])
+     array = numeric_df.astype(np.float32).to_numpy()
+     if array.shape[1] == n_features * sequence_length:
+         return array.reshape(array.shape[0], sequence_length, n_features)
+     if sequence_length == 1 and array.shape[1] == n_features:
+         return array.reshape(array.shape[0], 1, n_features)
+     raise ValueError(
+         "CSV columns do not match the expected feature layout. Include the full PMU feature set "
+         "or provide pre-shaped sliding window data."
+     )
+
+
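+ # Accepted layouts: (a) raw time series containing every feature column, which
+ # is windowed on the fly; (b) rows pre-flattened to seq_len * n_features values;
+ # (c) single-timestep rows when sequence_length == 1.
+
+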
+ def label_name(index: int) -> str:
+     if 0 <= index < len(LABEL_CLASSES):
+         return str(LABEL_CLASSES[index])
+     return f"class_{index}"
+
+
+ def format_predictions(probabilities: np.ndarray) -> pd.DataFrame:
+     rows: List[Dict[str, object]] = []
+     order = np.argsort(probabilities, axis=1)[:, ::-1]
+     for idx, (prob_row, ranking) in enumerate(zip(probabilities, order)):
+         top_idx = int(ranking[0])
+         top_label = label_name(top_idx)
+         top_conf = float(prob_row[top_idx])
+         top3 = [f"{label_name(i)} ({prob_row[i]*100:.2f}%)" for i in ranking[:3]]
+         rows.append(
+             {
+                 "window": idx,
+                 "predicted_label": top_label,
+                 "confidence": round(top_conf, 4),
+                 "top3": " | ".join(top3),
+             }
+         )
+     return pd.DataFrame(rows)
+
+
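+ # Each row summarises one window, e.g. (labels illustrative):
+ # window=0, predicted_label="LL-G", confidence=0.9731,
+ # top3="LL-G (97.31%) | LL (1.90%) | NF (0.42%)".
+
+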
+ def probabilities_to_json(probabilities: np.ndarray) -> List[Dict[str, object]]:
+     payload: List[Dict[str, object]] = []
+     for idx, prob_row in enumerate(probabilities):
+         payload.append(
+             {
+                 "window": int(idx),
+                 "probabilities": {
+                     label_name(i): float(prob_row[i]) for i in range(prob_row.shape[0])
+                 },
+             }
+         )
+     return payload
+
+
+ def predict_sequences(
+     sequences: np.ndarray,
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
+     ensure_ready()
+     sequences = apply_scaler(sequences.astype(np.float32))
+     if MODEL_TYPE == "svm":
+         flattened = sequences.reshape(sequences.shape[0], -1)
+         if hasattr(MODEL, "predict_proba"):
+             probs = MODEL.predict_proba(flattened)
+         else:
+             raise RuntimeError(
+                 "Loaded SVM model does not expose predict_proba. Retrain with probability=True."
+             )
+     else:
+         probs = MODEL.predict(sequences, verbose=0)
+     table = format_predictions(probs)
+     json_probs = probabilities_to_json(probs)
+     architecture = MODEL_TYPE.replace("_", "-").upper()
+     status = f"Generated {len(sequences)} windows. {architecture} model output dimension: {probs.shape[1]}."
+     return status, table, json_probs
+
+
+ def predict_from_text(
+     text: str, sequence_length: int
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
+     arr = parse_text_features(text)
+     n_features = len(FEATURE_COLUMNS)
+     if arr.size % n_features != 0:
+         raise ValueError(
+             f"The number of values ({arr.size}) is not a multiple of the feature dimension "
+             f"({n_features}). Provide values in groups of {n_features}."
+         )
+     timesteps = arr.size // n_features
+     if timesteps != sequence_length:
+         raise ValueError(
+             f"Detected {timesteps} timesteps which does not match the configured sequence length "
+             f"({sequence_length})."
+         )
+     sequences = arr.reshape(1, sequence_length, n_features)
+     status, table, probs = predict_sequences(sequences)
+     status = f"Single window prediction complete. {status}"
+     return status, table, probs
+
+
+ def predict_from_csv(
+     file_obj, sequence_length: int, stride: int
+ ) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
+     df = load_measurement_csv(file_obj.name)
+     sequences = dataframe_to_sequences(
+         df,
+         sequence_length=sequence_length,
+         stride=stride,
+         feature_columns=FEATURE_COLUMNS,
+     )
+     status, table, probs = predict_sequences(sequences)
+     status = f"CSV processed successfully. Generated {len(sequences)} windows. {status}"
+     return status, table, probs
+
+
+ # --------------------------------------------------------------------------------------
+ # Training helpers
+ # --------------------------------------------------------------------------------------
+
+
+ def classification_report_to_dataframe(report: Dict[str, Any]) -> pd.DataFrame:
+     rows: List[Dict[str, Any]] = []
+     for label, metrics in report.items():
+         if isinstance(metrics, dict):
+             row = {"label": label}
+             for key, value in metrics.items():
+                 if key == "support":
+                     row[key] = int(value)
+                 else:
+                     row[key] = round(float(value), 4)
+             rows.append(row)
+         else:
+             rows.append({"label": label, "accuracy": round(float(metrics), 4)})
+     return pd.DataFrame(rows)
+
+
+ def confusion_matrix_to_dataframe(
+     confusion: Sequence[Sequence[float]], labels: Sequence[str]
+ ) -> pd.DataFrame:
+     if not confusion:
+         return pd.DataFrame()
+     df = pd.DataFrame(confusion, index=list(labels), columns=list(labels))
+     df.index.name = "True Label"
+     df.columns.name = "Predicted Label"
+     return df
+
+
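+ # These helpers assume sklearn-style inputs: a classification_report(...,
+ # output_dict=True) mapping for the report, and a confusion matrix whose rows
+ # follow the same ordering as `labels`.
+
+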
+ # --------------------------------------------------------------------------------------
+ # Gradio interface
+ # --------------------------------------------------------------------------------------
+
+
+ def build_interface() -> gr.Blocks:
+     theme = gr.themes.Soft(
+         primary_hue="sky", secondary_hue="blue", neutral_hue="gray"
+     ).set(
+         body_background_fill="#1f1f1f",
+         body_text_color="#f5f5f5",
+         block_background_fill="#262626",
+         block_border_color="#333333",
+         button_primary_background_fill="#5ac8fa",
+         button_primary_background_fill_hover="#48b5eb",
+         button_primary_border_color="#38bdf8",
+         button_primary_text_color="#0f172a",
+         button_secondary_background_fill="#3f3f46",
+         button_secondary_text_color="#f5f5f5",
+     )
+
+     def _normalise_directory_string(value: Optional[Union[str, Path]]) -> str:
+         if value is None:
+             return ""
+         path = Path(value).expanduser()
+         try:
+             return str(path.resolve())
+         except Exception:
+             return str(path)
+
1573
+ with gr.Blocks(
1574
+ title="Fault Classification - PMU Data", theme=theme, css=APP_CSS
1575
+ ) as demo:
1576
+ gr.Markdown("# Fault Classification for PMU & PV Data")
1577
+ gr.Markdown(
1578
+ "🖥️ TensorFlow is locked to CPU execution so the Space can run without CUDA drivers."
1579
+ )
1580
+ if MODEL is None or SCALER is None:
1581
+ gr.Markdown(
1582
+ "⚠️ **Artifacts Missing** — Upload `pmu_cnn_lstm_model.keras`, "
1583
+ "`pmu_feature_scaler.pkl`, and `pmu_metadata.json` to enable inference, "
1584
+ "or configure the Hugging Face Hub environment variables so they can be downloaded."
1585
+ )
1586
+ else:
1587
+ class_count = len(LABEL_CLASSES) if LABEL_CLASSES else "unknown"
1588
+ gr.Markdown(
1589
+ f"Loaded a **{MODEL_TYPE.upper()}** model ({MODEL_FORMAT.upper()}) with "
1590
+ f"{len(FEATURE_COLUMNS)} features, sequence length **{SEQUENCE_LENGTH}**, and "
1591
+ f"{class_count} target classes. Use the tabs below to run inference or fine-tune "
1592
+ "the model with your own CSV files."
1593
+ )
1594
+
1595
+ with gr.Accordion("Feature Reference", open=False):
1596
+ gr.Markdown(
1597
+ f"Each time window expects **{len(FEATURE_COLUMNS)} features** ordered as follows:\n"
1598
+ + "\n".join(f"- {name}" for name in FEATURE_COLUMNS)
1599
+ )
1600
+ gr.Markdown(
1601
+ f"Default training parameters: **sequence length = {SEQUENCE_LENGTH}**, "
1602
+ f"**stride = {DEFAULT_WINDOW_STRIDE}**. Adjust them in the tabs as needed."
1603
+ )
1604
+
1605
+ with gr.Tabs():
1606
+ with gr.Tab("Overview"):
1607
+ gr.Markdown(PROJECT_OVERVIEW_MD)
1608
+ with gr.Tab("Inference"):
1609
+ gr.Markdown("## Run Inference")
1610
+ with gr.Row():
1611
+ file_in = gr.File(label="Upload PMU CSV", file_types=[".csv"])
1612
+ text_in = gr.Textbox(
1613
+ lines=4,
1614
+ label="Or paste a single window (comma separated)",
1615
+ placeholder="49.97772,1.215825E-38,...",
1616
+ )
1617
+
1618
+ with gr.Row():
1619
+ sequence_length_input = gr.Slider(
1620
+ minimum=1,
1621
+ maximum=max(1, SEQUENCE_LENGTH * 2),
1622
+ step=1,
1623
+ value=SEQUENCE_LENGTH,
1624
+ label="Sequence length (timesteps)",
1625
+ )
1626
+ stride_input = gr.Slider(
1627
+ minimum=1,
1628
+ maximum=max(1, SEQUENCE_LENGTH),
1629
+ step=1,
1630
+ value=max(1, DEFAULT_WINDOW_STRIDE),
1631
+ label="CSV window stride",
1632
+ )
1633
+
1634
+ predict_btn = gr.Button("🚀 Run Inference", variant="primary")
1635
+ status_out = gr.Textbox(label="Status", interactive=False)
1636
+ table_out = gr.Dataframe(
1637
+ headers=["window", "predicted_label", "confidence", "top3"],
1638
+ label="Predictions",
1639
+ interactive=False,
1640
+ )
1641
+ probs_out = gr.JSON(label="Per-window probabilities")
1642
+
1643
+ def _run_prediction(file_obj, text, sequence_length, stride):
1644
+ sequence_length = int(sequence_length)
1645
+ stride = int(stride)
1646
+ try:
1647
+ if file_obj is not None:
1648
+ return predict_from_csv(file_obj, sequence_length, stride)
1649
+ if text and text.strip():
1650
+ return predict_from_text(text, sequence_length)
1651
+ return (
1652
+ "Please upload a CSV file or provide feature values.",
1653
+ pd.DataFrame(),
1654
+ [],
1655
+ )
1656
+ except Exception as exc:
1657
+ return f"Prediction failed: {exc}", pd.DataFrame(), []
1658
+
1659
+ predict_btn.click(
1660
+ _run_prediction,
1661
+ inputs=[file_in, text_in, sequence_length_input, stride_input],
1662
+ outputs=[status_out, table_out, probs_out],
1663
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1664
+ )
1665
+
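The pasted-text path expects one window of sequence_length × len(FEATURE_COLUMNS) comma-separated numbers. A sketch of that contract (illustrative, not part of the commit; zero-valued features are hypothetical and the model/scaler artifacts must be loaded):

from app import FEATURE_COLUMNS, SEQUENCE_LENGTH, predict_from_text

one_window = ",".join("0.0" for _ in range(SEQUENCE_LENGTH * len(FEATURE_COLUMNS)))
status, table, probs = predict_from_text(one_window, SEQUENCE_LENGTH)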
1666
+ with gr.Tab("Training"):
1667
+ gr.Markdown("## Train or Fine-tune the Model")
1668
+ gr.Markdown(
1669
+ "Training data is automatically downloaded from the database. "
1670
+ "Refresh the cache if new files are added upstream."
1671
+ )
1672
+
1673
+ training_files_state = gr.State([])
1674
+ with gr.Row():
1675
+ with gr.Column(scale=3):
1676
+ training_files_summary = gr.Textbox(
1677
+ label="Database training CSVs",
1678
+ value="Training dataset not loaded yet.",
1679
+ lines=4,
1680
+ interactive=False,
1681
+ elem_id="training-files-summary",
1682
+ )
1683
+ with gr.Column(scale=2, min_width=240):
1684
+ dataset_info = gr.Markdown(
1685
+ "No local database CSVs downloaded yet.",
1686
+ )
1687
+ dataset_refresh = gr.Button(
1688
+ "🔄 Reload dataset from database",
1689
+ variant="secondary",
1690
+ )
1691
+ clear_cache_button = gr.Button(
1692
+ "🧹 Clear downloaded cache",
1693
+ variant="secondary",
1694
+ )
1695
+
1696
+ with gr.Accordion("📂 DataBaseBrowser", open=False):
1697
+ gr.Markdown(
1698
+ "Browse the upstream database by date and download only the CSVs you need."
1699
+ )
1700
+ with gr.Row(elem_id="date-browser-row"):
1701
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1702
+ year_selector = gr.Dropdown(label="Year", choices=[])
1703
+ year_download_button = gr.Button(
1704
+ "⬇️ Download year CSVs", variant="secondary"
1705
+ )
1706
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1707
+ month_selector = gr.Dropdown(label="Month", choices=[])
1708
+ month_download_button = gr.Button(
1709
+ "⬇️ Download month CSVs", variant="secondary"
1710
+ )
1711
+ with gr.Column(scale=1, elem_classes=["date-browser-column"]):
1712
+ day_selector = gr.Dropdown(label="Day", choices=[])
1713
+ day_download_button = gr.Button(
1714
+ "⬇️ Download day CSVs", variant="secondary"
1715
+ )
1716
+ with gr.Column(elem_id="available-files-section"):
1717
+ available_files = gr.CheckboxGroup(
1718
+ label="Available CSV files",
1719
+ choices=[],
1720
+ value=[],
1721
+ elem_id="available-files-grid",
1722
+ )
1723
+ download_button = gr.Button(
1724
+ "⬇️ Download selected CSVs",
1725
+ variant="secondary",
1726
+ elem_id="download-selected-button",
1727
+ )
1728
+ repo_status = gr.Markdown(
1729
+ "Click 'Reload dataset from database' to fetch the directory tree."
1730
+ )
1731
+
1732
+ with gr.Row():
1733
+ label_input = gr.Dropdown(
1734
+ value=LABEL_COLUMN,
1735
+ choices=[LABEL_COLUMN],
1736
+ allow_custom_value=True,
1737
+ label="Label column name",
1738
+ )
1739
+ model_selector = gr.Radio(
1740
+ choices=["CNN-LSTM", "TCN", "SVM"],
1741
+ value=(
1742
+ "TCN"
1743
+ if MODEL_TYPE == "tcn"
1744
+ else ("SVM" if MODEL_TYPE == "svm" else "CNN-LSTM")
1745
+ ),
1746
+ label="Model architecture",
1747
+ )
1748
+ sequence_length_train = gr.Slider(
1749
+ minimum=4,
1750
+ maximum=max(32, SEQUENCE_LENGTH * 2),
1751
+ step=1,
1752
+ value=SEQUENCE_LENGTH,
1753
+ label="Sequence length",
1754
+ )
1755
+ stride_train = gr.Slider(
1756
+ minimum=1,
1757
+ maximum=max(32, SEQUENCE_LENGTH * 2),
1758
+ step=1,
1759
+ value=max(1, DEFAULT_WINDOW_STRIDE),
1760
+ label="Stride",
1761
+ )
1762
+
1763
+ model_default = MODEL_FILENAME_BY_TYPE.get(
1764
+ MODEL_TYPE, Path(LOCAL_MODEL_FILE).name
1765
+ )
1766
+
1767
+ with gr.Row():
1768
+ validation_train = gr.Slider(
1769
+ minimum=0.05,
1770
+ maximum=0.4,
1771
+ step=0.05,
1772
+ value=0.2,
1773
+ label="Validation split",
1774
+ )
1775
+ batch_train = gr.Slider(
1776
+ minimum=32,
1777
+ maximum=512,
1778
+ step=32,
1779
+ value=128,
1780
+ label="Batch size",
1781
+ )
1782
+ epochs_train = gr.Slider(
1783
+ minimum=5,
1784
+ maximum=100,
1785
+ step=5,
1786
+ value=50,
1787
+ label="Epochs",
1788
+ )
1789
+
1790
+ directory_choices, directory_default = gather_directory_choices(
1791
+ str(MODEL_OUTPUT_DIR)
1792
+ )
1793
+ artifact_choices, default_artifact = gather_artifact_choices(
1794
+ directory_default
1795
+ )
1796
+
1797
+ with gr.Row():
1798
+ output_directory = gr.Dropdown(
1799
+ value=directory_default,
1800
+ label="Output directory",
1801
+ choices=directory_choices,
1802
+ allow_custom_value=True,
1803
+ )
1804
+ model_name = gr.Textbox(
1805
+ value=model_default,
1806
+ label="Model output filename",
1807
+ )
1808
+ scaler_name = gr.Textbox(
1809
+ value=Path(LOCAL_SCALER_FILE).name,
1810
+ label="Scaler output filename",
1811
+ )
1812
+ metadata_name = gr.Textbox(
1813
+ value=Path(LOCAL_METADATA_FILE).name,
1814
+ label="Metadata output filename",
1815
+ )
1816
+
1817
+ with gr.Row():
1818
+ artifact_browser = gr.Dropdown(
1819
+ label="Saved artifacts in directory",
1820
+ choices=artifact_choices,
1821
+ value=default_artifact,
1822
+ )
1823
+ artifact_download_button = gr.DownloadButton(
1824
+ "⬇️ Download selected artifact",
1825
+ value=default_artifact,
1826
+ visible=bool(default_artifact),
1827
+ variant="secondary",
1828
+ )
1829
+
1830
+ def on_output_directory_change(selected_dir, current_selection):
1831
+ choices, normalised = gather_directory_choices(selected_dir)
1832
+ artifact_options, selected = gather_artifact_choices(
1833
+ normalised, current_selection
1834
+ )
1835
+ return (
1836
+ gr.update(choices=choices, value=normalised),
1837
+ gr.update(choices=artifact_options, value=selected),
1838
+ download_button_state(selected),
1839
+ )
1840
+
1841
+ def on_artifact_change(selected_path):
1842
+ return download_button_state(selected_path)
1843
+
1844
+ output_directory.change(
1845
+ on_output_directory_change,
1846
+ inputs=[output_directory, artifact_browser],
1847
+ outputs=[
1848
+ output_directory,
1849
+ artifact_browser,
1850
+ artifact_download_button,
1851
+ ],
1852
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1853
+ )
1854
+
1855
+ artifact_browser.change(
1856
+ on_artifact_change,
1857
+ inputs=[artifact_browser],
1858
+ outputs=[artifact_download_button],
1859
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1860
+ )
1861
+
1862
+ with gr.Row(elem_id="artifact-download-row"):
1863
+ model_download_button = gr.DownloadButton(
1864
+ "⬇️ Download model file",
1865
+ value=None,
1866
+ visible=False,
1867
+ elem_classes=["artifact-download-button"],
1868
+ )
1869
+ scaler_download_button = gr.DownloadButton(
1870
+ "⬇️ Download scaler file",
1871
+ value=None,
1872
+ visible=False,
1873
+ elem_classes=["artifact-download-button"],
1874
+ )
1875
+ metadata_download_button = gr.DownloadButton(
1876
+ "⬇️ Download metadata file",
1877
+ value=None,
1878
+ visible=False,
1879
+ elem_classes=["artifact-download-button"],
1880
+ )
1881
+ tensorboard_download_button = gr.DownloadButton(
1882
+ "⬇️ Download TensorBoard logs",
1883
+ value=None,
1884
+ visible=False,
1885
+ elem_classes=["artifact-download-button"],
1886
+ )
1887
+
1888
+ model_download_button.file_name = Path(LOCAL_MODEL_FILE).name
1889
+ scaler_download_button.file_name = Path(LOCAL_SCALER_FILE).name
1890
+ metadata_download_button.file_name = Path(LOCAL_METADATA_FILE).name
1891
+ tensorboard_download_button.file_name = "tensorboard_logs.zip"
1892
+
1893
+ tensorboard_toggle = gr.Checkbox(
1894
+ value=True,
1895
+ label="Enable TensorBoard logging (creates downloadable archive)",
1896
+ )
1897
+
1898
+ def _suggest_model_filename(choice: str, current_value: str):
1899
+ choice_key = (choice or "cnn_lstm").lower().replace("-", "_")
1900
+ suggested = MODEL_FILENAME_BY_TYPE.get(
1901
+ choice_key, Path(LOCAL_MODEL_FILE).name
1902
+ )
1903
+ known_defaults = set(MODEL_FILENAME_BY_TYPE.values())
1904
+ current_name = Path(current_value).name if current_value else ""
1905
+ if current_name and current_name not in known_defaults:
1906
+ return gr.update()
1907
+ return gr.update(value=suggested)
1908
+
1909
+ model_selector.change(
1910
+ _suggest_model_filename,
1911
+ inputs=[model_selector, model_name],
1912
+ outputs=model_name,
1913
+ )
1914
+
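The rule implemented by _suggest_model_filename, restated as a standalone sketch (not part of the commit; the tcn/svm filenames are hypothetical stand-ins for MODEL_FILENAME_BY_TYPE): a custom filename survives architecture switches, while a known default is swapped for the new architecture's default.

defaults = {
    "cnn_lstm": "pmu_cnn_lstm_model.keras",
    "tcn": "pmu_tcn_model.keras",   # hypothetical
    "svm": "pmu_svm_model.pkl",     # hypothetical
}

def suggest(choice: str, current: str) -> str:
    key = (choice or "cnn_lstm").lower().replace("-", "_")
    if current and current not in defaults.values():
        return current              # keep the user's custom name
    return defaults.get(key, defaults["cnn_lstm"])

assert suggest("TCN", "pmu_cnn_lstm_model.keras") == "pmu_tcn_model.keras"
assert suggest("TCN", "my_custom_name.keras") == "my_custom_name.keras"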
1915
+ with gr.Row():
1916
+ train_button = gr.Button("🛠️ Start Training", variant="primary")
1917
+ progress_button = gr.Button(
1918
+ "📊 Check Progress", variant="secondary"
1919
+ )
1920
+
1921
+ # Training status display
1922
+ training_status = gr.Textbox(label="Training Status", interactive=False)
1923
+ report_output = gr.Dataframe(
1924
+ label="Classification report", interactive=False
1925
+ )
1926
+ history_output = gr.JSON(label="Training history")
1927
+ confusion_output = gr.Dataframe(
1928
+ label="Confusion matrix", interactive=False
1929
+ )
1930
+
1931
+ # Message area at the bottom for progress updates
1932
+ with gr.Accordion("📋 Progress Messages", open=True):
1933
+ progress_messages = gr.Textbox(
1934
+ label="Training Messages",
1935
+ lines=8,
1936
+ max_lines=20,
1937
+ interactive=False,
1938
+ autoscroll=True,
1939
+ placeholder="Click 'Check Progress' to see training updates...",
1940
+ )
1941
+ with gr.Row():
1942
+ gr.Button("🗑️ Clear Messages", variant="secondary").click(
1943
+ lambda: "", outputs=[progress_messages]
1944
+ )
1945
+
1946
+ def _run_training(
1947
+ file_paths,
1948
+ label_column,
1949
+ model_choice,
1950
+ sequence_length,
1951
+ stride,
1952
+ validation_split,
1953
+ batch_size,
1954
+ epochs,
1955
+ output_dir,
1956
+ model_filename,
1957
+ scaler_filename,
1958
+ metadata_filename,
1959
+ enable_tensorboard,
1960
+ ):
1961
+ base_dir = normalise_output_directory(output_dir)
1962
+ try:
1963
+ base_dir.mkdir(parents=True, exist_ok=True)
1964
+
1965
+ model_path = resolve_output_path(
1966
+ base_dir,
1967
+ model_filename,
1968
+ Path(LOCAL_MODEL_FILE).name,
1969
+ )
1970
+ scaler_path = resolve_output_path(
1971
+ base_dir,
1972
+ scaler_filename,
1973
+ Path(LOCAL_SCALER_FILE).name,
1974
+ )
1975
+ metadata_path = resolve_output_path(
1976
+ base_dir,
1977
+ metadata_filename,
1978
+ Path(LOCAL_METADATA_FILE).name,
1979
+ )
1980
+
1981
+ model_path.parent.mkdir(parents=True, exist_ok=True)
1982
+ scaler_path.parent.mkdir(parents=True, exist_ok=True)
1983
+ metadata_path.parent.mkdir(parents=True, exist_ok=True)
1984
+
1985
+ # Create status file path for progress tracking
1986
+ status_file = model_path.parent / "training_status.txt"
1987
+
1988
+ # Initialize status
1989
+ with open(status_file, "w") as f:
1990
+ f.write("Starting training setup...")
1991
+
1992
+ if not file_paths:
1993
+ raise ValueError(
1994
+ "No training CSVs were found in the database cache. "
1995
+ "Use 'Reload dataset from database' and try again."
1996
+ )
1997
+
1998
+ with open(status_file, "w") as f:
1999
+ f.write("Loading and validating CSV files...")
2000
+
2001
+ available_paths = [
2002
+ path for path in file_paths if Path(path).exists()
2003
+ ]
2004
+ missing_paths = [
2005
+ Path(path).name
2006
+ for path in file_paths
2007
+ if not Path(path).exists()
2008
+ ]
2009
+ if not available_paths:
2010
+ raise ValueError(
2011
+ "Database training dataset is unavailable. Reload the dataset and retry."
2012
+ )
2013
+
2014
+ dfs = [load_measurement_csv(path) for path in available_paths]
2015
+ combined = pd.concat(dfs, ignore_index=True)
2016
+
2017
+ # Validate data size and provide recommendations
2018
+ total_samples = len(combined)
2019
+ if total_samples < 100:
2020
+ print(
2021
+ f"Warning: Only {total_samples} samples. Recommend at least 1000 for good results."
2022
+ )
2023
+ # The radio sends "CNN-LSTM"/"TCN"/"SVM", so normalise before comparing
+ # and only announce the switch when it actually happens.
+ if (model_choice or "").lower().replace("-", "_") in {"cnn_lstm", "tcn"}:
+ print(
+ "Automatically switching to SVM for small dataset compatibility."
+ )
+ model_choice = "svm"
+ print(
+ "Model type changed to SVM for better small dataset performance."
+ )
2031
+ if total_samples < 10:
2032
+ raise ValueError(
2033
+ f"Insufficient data: {total_samples} samples. Need at least 10 samples for training."
2034
+ )
2035
+
2036
+ label_column = (label_column or LABEL_COLUMN).strip()
2037
+ if not label_column:
2038
+ raise ValueError("Label column name cannot be empty.")
2039
+
2040
+ model_choice = (
2041
+ (model_choice or "CNN-LSTM").lower().replace("-", "_")
2042
+ )
2043
+ if model_choice not in {"cnn_lstm", "tcn", "svm"}:
2044
+ raise ValueError(
2045
+ "Select CNN-LSTM, TCN, or SVM for the model architecture."
2046
+ )
2047
+
2048
+ with open(status_file, "w") as f:
2049
+ f.write(
2050
+ f"Starting {model_choice.upper()} training with {len(combined)} samples..."
2051
+ )
2052
+
2053
+ # Start training
2054
+ result = train_from_dataframe(
2055
+ combined,
2056
+ label_column=label_column,
2057
+ feature_columns=None,
2058
+ sequence_length=int(sequence_length),
2059
+ stride=int(stride),
2060
+ validation_split=float(validation_split),
2061
+ batch_size=int(batch_size),
2062
+ epochs=int(epochs),
2063
+ model_type=model_choice,
2064
+ model_path=model_path,
2065
+ scaler_path=scaler_path,
2066
+ metadata_path=metadata_path,
2067
+ enable_tensorboard=bool(enable_tensorboard),
2068
+ )
2069
+
2070
+ refresh_artifacts(
2071
+ Path(result["model_path"]),
2072
+ Path(result["scaler_path"]),
2073
+ Path(result["metadata_path"]),
2074
+ )
2075
+
2076
+ report_df = classification_report_to_dataframe(
2077
+ result["classification_report"]
2078
+ )
2079
+ confusion_df = confusion_matrix_to_dataframe(
2080
+ result["confusion_matrix"], result["class_names"]
2081
+ )
2082
+ tensorboard_dir = result.get("tensorboard_log_dir")
2083
+ tensorboard_zip = result.get("tensorboard_zip_path")
2084
+
2085
+ architecture = result["model_type"].replace("_", "-").upper()
2086
+ status = (
2087
+ f"Training complete using a {architecture} architecture. "
2088
+ f"{result['num_sequences']} windows derived from "
2089
+ f"{result['num_samples']} rows across {len(available_paths)} file(s)."
2090
+ f" Artifacts saved to:"
2091
+ f"\n• Model: {result['model_path']}\n"
2092
+ f"• Scaler: {result['scaler_path']}\n"
2093
+ f"• Metadata: {result['metadata_path']}"
2094
+ )
2095
+
2096
+ status += f"\nLabel column used: {result.get('label_column', label_column)}"
2097
+
2098
+ if tensorboard_dir:
2099
+ status += (
2100
+ f"\nTensorBoard logs directory: {tensorboard_dir}"
2101
+ f'\nRun `tensorboard --logdir "{tensorboard_dir}"` to inspect the training curves.'
2102
+ "\nDownload the archive below to explore the run offline."
2103
+ )
2104
+
2105
+ if missing_paths:
2106
+ skipped = ", ".join(missing_paths)
2107
+ status = f"⚠️ Skipped missing files: {skipped}\n" + status
2108
+
2109
+ artifact_choices, selected_artifact = gather_artifact_choices(
2110
+ str(base_dir), result["model_path"]
2111
+ )
2112
+
2113
+ return (
2114
+ status,
2115
+ report_df,
2116
+ result["history"],
2117
+ confusion_df,
2118
+ download_button_state(result["model_path"]),
2119
+ download_button_state(result["scaler_path"]),
2120
+ download_button_state(result["metadata_path"]),
2121
+ download_button_state(tensorboard_zip),
2122
+ gr.update(value=result.get("label_column", label_column)),
2123
+ gr.update(
2124
+ choices=artifact_choices, value=selected_artifact
2125
+ ),
2126
+ download_button_state(selected_artifact),
2127
+ )
2128
+ except Exception as exc:
2129
+ artifact_choices, selected_artifact = gather_artifact_choices(
2130
+ str(base_dir)
2131
+ )
2132
+ return (
2133
+ f"Training failed: {exc}",
2134
+ pd.DataFrame(),
2135
+ {},
2136
+ pd.DataFrame(),
2137
+ download_button_state(None),
2138
+ download_button_state(None),
2139
+ download_button_state(None),
2140
+ download_button_state(None),
2141
+ gr.update(),
2142
+ gr.update(
2143
+ choices=artifact_choices, value=selected_artifact
2144
+ ),
2145
+ download_button_state(selected_artifact),
2146
+ )
2147
+
2148
+ def _check_progress(output_dir, model_filename, current_messages):
2149
+ """Check training progress by reading status file and accumulate messages."""
2150
+ model_path = resolve_output_path(
2151
+ output_dir, model_filename, Path(LOCAL_MODEL_FILE).name
2152
+ )
2153
+ status_file = model_path.parent / "training_status.txt"
2154
+ status_message = read_training_status(str(status_file))
2155
+
2156
+ # Add timestamp to the message
2157
+ from datetime import datetime
2158
+
2159
+ timestamp = datetime.now().strftime("%H:%M:%S")
2160
+ new_message = f"[{timestamp}] {status_message}"
2161
+
2162
+ # Accumulate messages, keeping last 50 lines to prevent overflow
2163
+ if current_messages:
2164
+ lines = current_messages.split("\n")
2165
+ lines.append(new_message)
2166
+ # Keep only last 50 lines
2167
+ if len(lines) > 50:
2168
+ lines = lines[-50:]
2169
+ accumulated_messages = "\n".join(lines)
2170
+ else:
2171
+ accumulated_messages = new_message
2172
+
2173
+ return accumulated_messages
2174
+
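The accumulation policy above, isolated as a runnable sketch (illustrative, not part of the commit; standard library only): each poll appends one timestamped line and only the 50 most recent lines are kept.

from datetime import datetime

def append_capped(current: str, status: str, cap: int = 50) -> str:
    stamped = f"[{datetime.now().strftime('%H:%M:%S')}] {status}"
    lines = current.split("\n") if current else []
    lines.append(stamped)
    return "\n".join(lines[-cap:])

log = ""
for i in range(60):
    log = append_capped(log, f"Epoch {i + 1}/60 complete")
print(len(log.split("\n")))   # 50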
2175
+ train_button.click(
2176
+ _run_training,
2177
+ inputs=[
2178
+ training_files_state,
2179
+ label_input,
2180
+ model_selector,
2181
+ sequence_length_train,
2182
+ stride_train,
2183
+ validation_train,
2184
+ batch_train,
2185
+ epochs_train,
2186
+ output_directory,
2187
+ model_name,
2188
+ scaler_name,
2189
+ metadata_name,
2190
+ tensorboard_toggle,
2191
+ ],
2192
+ outputs=[
2193
+ training_status,
2194
+ report_output,
2195
+ history_output,
2196
+ confusion_output,
2197
+ model_download_button,
2198
+ scaler_download_button,
2199
+ metadata_download_button,
2200
+ tensorboard_download_button,
2201
+ label_input,
2202
+ artifact_browser,
2203
+ artifact_download_button,
2204
+ ],
2205
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2206
+ )
2207
+
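A programmatic training sketch mirroring the keyword arguments used in _run_training (illustrative, not part of the commit; assumes this file is importable as app, and the CSV path, sequence length, stride, and epoch count are hypothetical):

from pathlib import Path
import pandas as pd
from app import train_from_dataframe

df = pd.read_csv("training_dataset/2024/01/01/example.csv")   # hypothetical file
out = Path("models")
out.mkdir(parents=True, exist_ok=True)
result = train_from_dataframe(
    df,
    label_column="Fault",
    feature_columns=None,    # fall back to the metadata defaults
    sequence_length=32,      # hypothetical values
    stride=4,
    validation_split=0.2,
    batch_size=128,
    epochs=10,
    model_type="cnn_lstm",
    model_path=out / "pmu_cnn_lstm_model.keras",
    scaler_path=out / "pmu_feature_scaler.pkl",
    metadata_path=out / "pmu_metadata.json",
    enable_tensorboard=False,
)
print(result["model_path"], result["num_sequences"])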
2208
+ progress_button.click(
2209
+ _check_progress,
2210
+ inputs=[output_directory, model_name, progress_messages],
2211
+ outputs=[progress_messages],
2212
+ )
2213
+
2214
+ year_selector.change(
2215
+ on_year_change,
2216
+ inputs=[year_selector],
2217
+ outputs=[
2218
+ month_selector,
2219
+ day_selector,
2220
+ available_files,
2221
+ repo_status,
2222
+ ],
2223
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2224
+ )
2225
+
2226
+ month_selector.change(
2227
+ on_month_change,
2228
+ inputs=[year_selector, month_selector],
2229
+ outputs=[day_selector, available_files, repo_status],
2230
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2231
+ )
2232
+
2233
+ day_selector.change(
2234
+ on_day_change,
2235
+ inputs=[year_selector, month_selector, day_selector],
2236
+ outputs=[available_files, repo_status],
2237
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2238
+ )
2239
+
2240
+ download_button.click(
2241
+ download_selected_files,
2242
+ inputs=[
2243
+ year_selector,
2244
+ month_selector,
2245
+ day_selector,
2246
+ available_files,
2247
+ label_input,
2248
+ ],
2249
+ outputs=[
2250
+ training_files_state,
2251
+ training_files_summary,
2252
+ label_input,
2253
+ dataset_info,
2254
+ available_files,
2255
+ repo_status,
2256
+ ],
2257
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2258
+ )
2259
+
2260
+ year_download_button.click(
2261
+ download_year_bundle,
2262
+ inputs=[year_selector, label_input],
2263
+ outputs=[
2264
+ training_files_state,
2265
+ training_files_summary,
2266
+ label_input,
2267
+ dataset_info,
2268
+ available_files,
2269
+ repo_status,
2270
+ ],
2271
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2272
+ )
2273
+
2274
+ month_download_button.click(
2275
+ download_month_bundle,
2276
+ inputs=[year_selector, month_selector, label_input],
2277
+ outputs=[
2278
+ training_files_state,
2279
+ training_files_summary,
2280
+ label_input,
2281
+ dataset_info,
2282
+ available_files,
2283
+ repo_status,
2284
+ ],
2285
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2286
+ )
2287
+
2288
+ day_download_button.click(
2289
+ download_day_bundle,
2290
+ inputs=[year_selector, month_selector, day_selector, label_input],
2291
+ outputs=[
2292
+ training_files_state,
2293
+ training_files_summary,
2294
+ label_input,
2295
+ dataset_info,
2296
+ available_files,
2297
+ repo_status,
2298
+ ],
2299
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2300
+ )
2301
+
2302
+ def _reload_dataset(current_label):
2303
+ local = load_repository_training_files(
2304
+ current_label, force_refresh=True
2305
+ )
2306
+ remote = refresh_remote_browser(force_refresh=True)
2307
+ return (*local, *remote)
2308
+
2309
+ dataset_refresh.click(
2310
+ _reload_dataset,
2311
+ inputs=[label_input],
2312
+ outputs=[
2313
+ training_files_state,
2314
+ training_files_summary,
2315
+ label_input,
2316
+ dataset_info,
2317
+ year_selector,
2318
+ month_selector,
2319
+ day_selector,
2320
+ available_files,
2321
+ repo_status,
2322
+ ],
2323
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2324
+ )
2325
+
2326
+ clear_cache_button.click(
2327
+ clear_downloaded_cache,
2328
+ inputs=[label_input],
2329
+ outputs=[
2330
+ training_files_state,
2331
+ training_files_summary,
2332
+ label_input,
2333
+ dataset_info,
2334
+ year_selector,
2335
+ month_selector,
2336
+ day_selector,
2337
+ available_files,
2338
+ repo_status,
2339
+ ],
2340
+ concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2341
+ )
2342
+
2343
+ def _initialise_dataset():
2344
+ local = load_repository_training_files(
2345
+ LABEL_COLUMN, force_refresh=False
2346
+ )
2347
+ remote = refresh_remote_browser(force_refresh=False)
2348
+ return (*local, *remote)
2349
+
2350
+ demo.load(
2351
+ _initialise_dataset,
2352
+ inputs=None,
2353
+ outputs=[
2354
+ training_files_state,
2355
+ training_files_summary,
2356
+ label_input,
2357
+ dataset_info,
2358
+ year_selector,
2359
+ month_selector,
2360
+ day_selector,
2361
+ available_files,
2362
+ repo_status,
2363
+ ],
2364
+ queue=False,
2365
+ )
2366
+
2367
+ return demo
2368
+
2369
+
2370
+ # --------------------------------------------------------------------------------------
2371
+ # Launch helpers
2372
+ # --------------------------------------------------------------------------------------
2373
+
2374
+
2375
+ def resolve_server_port() -> int:
2376
+ for env_var in ("PORT", "GRADIO_SERVER_PORT"):
2377
+ value = os.environ.get(env_var)
2378
+ if value:
2379
+ try:
2380
+ return int(value)
2381
+ except ValueError:
2382
+ print(f"Ignoring invalid port value from {env_var}: {value}")
2383
+ return 7860
2384
+
2385
+
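Behaviour sketch for the port resolution (illustrative, not part of the commit; the values are hypothetical and this file is assumed importable as app): PORT wins over GRADIO_SERVER_PORT, and a non-numeric value is skipped with a printed warning before falling back to 7860.

import os
from app import resolve_server_port

os.environ["PORT"] = "8080"
assert resolve_server_port() == 8080

os.environ["PORT"] = "not-a-number"   # invalid values are skipped with a warning
os.environ.pop("GRADIO_SERVER_PORT", None)
assert resolve_server_port() == 7860  # Gradio's default port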
2386
+ def main():
2387
+ print("Building Gradio interface...")
2388
+ try:
2389
+ demo = build_interface()
2390
+ print("Interface built successfully")
2391
+ except Exception as e:
2392
+ print(f"Failed to build interface: {e}")
2393
+ import traceback
2394
+
2395
+ traceback.print_exc()
2396
+ return
2397
+
2398
+ print("Setting up queue...")
2399
+ try:
2400
+ demo.queue(max_size=QUEUE_MAX_SIZE)
2401
+ print("Queue configured")
2402
+ except Exception as e:
2403
+ print(f"Failed to configure queue: {e}")
2404
+
2405
+ try:
2406
+ port = resolve_server_port()
2407
+ print(f"Launching Gradio app on port {port}")
2408
+ demo.launch(server_name="0.0.0.0", server_port=port, show_error=True)
2409
+ except OSError as exc:
2410
+ print("Failed to launch on requested port:", exc)
2411
+ try:
2412
+ demo.launch(server_name="0.0.0.0", show_error=True)
2413
+ except Exception as e:
2414
+ print(f"Failed to launch completely: {e}")
2415
+ except Exception as e:
2416
+ print(f"Unexpected launch error: {e}")
2417
+ import traceback
2418
+
2419
+ traceback.print_exc()
2420
+
2421
+
2422
+ if __name__ == "__main__":
2423
+ print("=" * 50)
2424
+ print("PMU Fault Classification App Starting")
2425
+ print(f"Python version: {os.sys.version}")
2426
+ print(f"Working directory: {os.getcwd()}")
2427
+ print(f"HUB_REPO: {HUB_REPO}")
2428
+ print(f"Model available: {MODEL is not None}")
2429
+ print(f"Scaler available: {SCALER is not None}")
2430
+ print("=" * 50)
2431
+ main()
app.py CHANGED
@@ -6,6 +6,7 @@ prediction interface optimised for Hugging Face Spaces deployment. It supports
6
  raw PMU time-series CSV uploads as well as manual comma separated feature
7
  vectors.
8
  """
 
9
  from __future__ import annotations
10
 
11
  import json
@@ -61,6 +62,7 @@ ENV_METADATA_PATH = "PMU_METADATA_PATH"
61
  # Utility functions for loading artifacts
62
  # --------------------------------------------------------------------------------------
63
 
 
64
  def download_from_hub(filename: str) -> Optional[Path]:
65
  if not HUB_REPO or not filename:
66
  return None
@@ -76,7 +78,9 @@ def download_from_hub(filename: str) -> Optional[Path]:
76
  return None
77
 
78
 
79
- def resolve_artifact(local_name: str, env_var: str, hub_filename: str) -> Optional[Path]:
 
 
80
  print(f"Resolving artifact: {local_name}, env: {env_var}, hub: {hub_filename}")
81
  candidates = [Path(local_name)] if local_name else []
82
  if local_name:
@@ -130,14 +134,18 @@ except Exception as e:
130
  MODEL_PATH = None
131
 
132
  try:
133
- SCALER_PATH = resolve_artifact(LOCAL_SCALER_FILE, ENV_SCALER_PATH, HUB_SCALER_FILENAME)
 
 
134
  print(f"Scaler path resolved: {SCALER_PATH}")
135
  except Exception as e:
136
  print(f"Scaler path resolution failed: {e}")
137
  SCALER_PATH = None
138
 
139
  try:
140
- METADATA_PATH = resolve_artifact(LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME)
 
 
141
  print(f"Metadata path resolved: {METADATA_PATH}")
142
  except Exception as e:
143
  print(f"Metadata path resolution failed: {e}")
@@ -156,6 +164,7 @@ QUEUE_MAX_SIZE = 32
156
  # ``concurrency_count`` parameter when enabling Gradio's request queue.
157
  EVENT_CONCURRENCY_LIMIT = 2
158
 
 
159
  def try_load_model(path: Optional[Path], model_type: str, model_format: str):
160
  if not path:
161
  return None
@@ -179,6 +188,7 @@ DEFAULT_WINDOW_STRIDE: int = DEFAULT_STRIDE
179
  MODEL_TYPE: str = "cnn_lstm"
180
  MODEL_FORMAT: str = "keras"
181
 
 
182
  def _model_output_path(filename: str) -> str:
183
  return str(MODEL_OUTPUT_DIR / Path(filename).name)
184
 
@@ -190,10 +200,14 @@ MODEL_FILENAME_BY_TYPE: Dict[str, str] = {
190
  }
191
 
192
  REQUIRED_PMU_COLUMNS: Tuple[str, ...] = tuple(DEFAULT_FEATURE_COLUMNS)
193
- TRAINING_UPLOAD_DIR = Path(os.environ.get("PMU_TRAINING_UPLOAD_DIR", "training_uploads"))
 
 
194
  TRAINING_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
195
 
196
- TRAINING_DATA_REPO = os.environ.get("PMU_TRAINING_DATA_REPO", "VincentCroft/ThesisModelData")
 
 
197
  TRAINING_DATA_BRANCH = os.environ.get("PMU_TRAINING_DATA_BRANCH", "main")
198
  TRAINING_DATA_DIR = Path(os.environ.get("PMU_TRAINING_DATA_DIR", "training_dataset"))
199
  TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
@@ -208,6 +222,7 @@ APP_CSS = """
208
  flex-direction: column;
209
  gap: 0.75rem;
210
  border-radius: 0.75rem;
 
211
  }
212
 
213
  #available-files-grid {
@@ -220,10 +235,6 @@ APP_CSS = """
220
  min-height: 16rem;
221
  }
222
 
223
- #available-files-section:has(.gradio-loading) {
224
- isolation: isolate;
225
- }
226
-
227
  #available-files-grid .wrap {
228
  display: grid;
229
  grid-template-columns: repeat(4, minmax(0, 1fr));
@@ -261,26 +272,55 @@ APP_CSS = """
261
  white-space: nowrap;
262
  }
263
 
264
  #available-files-grid .gradio-loading {
265
  position: absolute;
266
- inset: 0;
267
- width: auto;
268
- height: auto;
269
- min-height: 100%;
 
 
270
  display: flex;
271
  align-items: center;
272
  justify-content: center;
273
- background: rgba(10, 14, 23, 0.72);
274
  border-radius: 0.75rem;
275
- z-index: 10;
276
  padding: 1.5rem;
277
  pointer-events: auto;
278
  }
279
 
 
280
  #available-files-grid .gradio-loading > * {
281
  width: 100%;
282
  }
283
 
284
  #available-files-grid .gradio-loading progress,
285
  #available-files-grid .gradio-loading .progress-bar,
286
  #available-files-grid .gradio-loading .loading-progress,
@@ -291,6 +331,9 @@ APP_CSS = """
291
  max-width: none !important;
292
  }
293
 
294
  #available-files-grid .gradio-loading .status,
295
  #available-files-grid .gradio-loading .message,
296
  #available-files-grid .gradio-loading .label {
@@ -363,7 +406,9 @@ def _github_api_url(path: str) -> str:
363
  return f"{base}?ref={TRAINING_DATA_BRANCH}"
364
 
365
 
366
- def list_remote_directory(path: str = "", *, force_refresh: bool = False) -> List[Dict[str, Any]]:
 
 
367
  key = _github_cache_key(path)
368
  if not force_refresh and key in GITHUB_CONTENT_CACHE:
369
  return GITHUB_CONTENT_CACHE[key]
@@ -377,7 +422,9 @@ def list_remote_directory(path: str = "", *, force_refresh: bool = False) -> Lis
377
 
378
  payload = response.json()
379
  if not isinstance(payload, list):
380
- raise RuntimeError("Unexpected GitHub API payload. Expected a directory listing.")
 
 
381
 
382
  GITHUB_CONTENT_CACHE[key] = payload
383
  return payload
@@ -397,7 +444,9 @@ def list_remote_months(year: str, *, force_refresh: bool = False) -> List[str]:
397
  return sorted(months)
398
 
399
 
400
- def list_remote_days(year: str, month: str, *, force_refresh: bool = False) -> List[str]:
 
 
401
  if not year or not month:
402
  return []
403
  entries = list_remote_directory(f"{year}/{month}", force_refresh=force_refresh)
@@ -405,7 +454,9 @@ def list_remote_days(year: str, month: str, *, force_refresh: bool = False) -> L
405
  return sorted(days)
406
 
407
 
408
- def list_remote_files(year: str, month: str, day: str, *, force_refresh: bool = False) -> List[str]:
 
 
409
  if not year or not month or not day:
410
  return []
411
  entries = list_remote_directory(
@@ -451,7 +502,9 @@ def _normalise_header(name: str) -> str:
451
  return str(name).strip().lower()
452
 
453
 
454
- def guess_label_from_columns(columns: Sequence[str], preferred: Optional[str] = None) -> Optional[str]:
 
 
455
  if not columns:
456
  return preferred
457
 
@@ -488,7 +541,7 @@ def read_training_status(status_file_path: str) -> str:
488
  """Read the current training status from file."""
489
  try:
490
  if Path(status_file_path).exists():
491
- with open(status_file_path, 'r') as f:
492
  return f.read().strip()
493
  except Exception:
494
  pass
@@ -545,11 +598,17 @@ def prepare_training_paths(
545
 
546
  summary = summarise_training_files(valid_paths, notes)
547
  preferred = current_label or LABEL_COLUMN
548
- dropdown_choices = sorted(columns_map.values()) if columns_map else [preferred or LABEL_COLUMN]
 
 
549
  guessed = guess_label_from_columns(dropdown_choices, preferred)
550
  dropdown_value = guessed or preferred or LABEL_COLUMN
551
 
552
- return valid_paths, summary, gr.update(choices=dropdown_choices, value=dropdown_value)
 
 
 
 
553
 
554
 
555
  def append_training_files(new_files, existing_paths: Sequence[str], current_label: str):
@@ -580,9 +639,7 @@ def load_repository_training_files(current_label: str, force_refresh: bool = Fal
580
  break
581
 
582
  csv_paths = sorted(
583
- str(path)
584
- for path in TRAINING_DATA_DIR.rglob("*.csv")
585
- if path.is_file()
586
  )
587
  if not csv_paths:
588
  message = (
@@ -735,7 +792,9 @@ def download_selected_files(
735
  notes: List[str] = []
736
  for filename in filenames:
737
  try:
738
- path = download_repository_file(year or "", month or "", day or "", filename)
 
 
739
  success.append(str(path))
740
  except Exception as exc:
741
  notes.append(f"⚠️ {filename}: {exc}")
@@ -840,9 +899,7 @@ def download_month_bundle(
840
  download_repository_file(year, month, day, filename)
841
  downloaded += 1
842
  except Exception as exc:
843
- notes.append(
844
- f"⚠️ {year}/{month}/{day}/{filename}: {exc}"
845
- )
846
 
847
  local = load_repository_training_files(current_label)
848
  message_lines = []
@@ -910,9 +967,7 @@ def download_year_bundle(year: Optional[str], current_label: str):
910
  download_repository_file(year, month, day, filename)
911
  downloaded += 1
912
  except Exception as exc:
913
- notes.append(
914
- f"⚠️ {year}/{month}/{day}/{filename}: {exc}"
915
- )
916
 
917
  local = load_repository_training_files(current_label)
918
  message_lines = []
@@ -1055,6 +1110,7 @@ def clear_training_files():
1055
  gr.update(value=None),
1056
  )
1057
 
 
1058
  PROJECT_OVERVIEW_MD = """
1059
  ## Project Overview
1060
 
@@ -1128,7 +1184,9 @@ def load_measurement_csv(path: str) -> pd.DataFrame:
1128
  df = None
1129
  for separator in ("\t", ",", ";"):
1130
  try:
1131
- df = pd.read_csv(path, sep=separator, engine="python", encoding="utf-8-sig")
 
 
1132
  break
1133
  except Exception:
1134
  df = None
@@ -1144,13 +1202,19 @@ def load_measurement_csv(path: str) -> pd.DataFrame:
1144
 
1145
  # Check if we have enough data for training
1146
  if len(df) < 100:
1147
- print(f"Warning: Only {len(df)} rows of data. Recommend at least 1000 rows for effective training.")
 
 
1148
 
1149
  # Check for label column
1150
- has_label = any(col.lower() in ['fault', 'label', 'class', 'target'] for col in df.columns)
 
 
1151
  if not has_label:
1152
- print("Warning: No label column found. Adding dummy 'Fault' column with value 'Normal' for all samples.")
1153
- df['Fault'] = 'Normal' # Add dummy label for training
 
 
1154
 
1155
  # Create column mapping - map similar column names to expected format
1156
  column_mapping = {}
@@ -1198,7 +1262,9 @@ def load_measurement_csv(path: str) -> pd.DataFrame:
1198
 
1199
  def apply_metadata(metadata: Dict[str, Any]) -> None:
1200
  global FEATURE_COLUMNS, LABEL_CLASSES, LABEL_COLUMN, SEQUENCE_LENGTH, DEFAULT_WINDOW_STRIDE, MODEL_TYPE, MODEL_FORMAT
1201
- FEATURE_COLUMNS = [str(col) for col in metadata.get("feature_columns", DEFAULT_FEATURE_COLUMNS)]
 
 
1202
  LABEL_CLASSES = [str(label) for label in metadata.get("label_classes", [])]
1203
  LABEL_COLUMN = str(metadata.get("label_column", "Fault"))
1204
  SEQUENCE_LENGTH = int(metadata.get("sequence_length", DEFAULT_SEQUENCE_LENGTH))
@@ -1211,6 +1277,7 @@ def apply_metadata(metadata: Dict[str, Any]) -> None:
1211
 
1212
  apply_metadata(METADATA)
1213
 
 
1214
  def sync_label_classes_from_model(model: Optional[object]) -> None:
1215
  global LABEL_CLASSES
1216
  if model is None:
@@ -1244,7 +1311,9 @@ except Exception as e:
1244
  print(f"Label sync failed: {e}")
1245
 
1246
  print("Application initialization completed.")
1247
- print(f"Ready to start Gradio interface. Model available: {MODEL is not None}, Scaler available: {SCALER is not None}")
 
 
1248
 
1249
 
1250
  def refresh_artifacts(model_path: Path, scaler_path: Path, metadata_path: Path) -> None:
@@ -1258,10 +1327,12 @@ def refresh_artifacts(model_path: Path, scaler_path: Path, metadata_path: Path)
1258
  SCALER = try_load_scaler(scaler_path)
1259
  sync_label_classes_from_model(MODEL)
1260
 
 
1261
  # --------------------------------------------------------------------------------------
1262
  # Pre-processing helpers
1263
  # --------------------------------------------------------------------------------------
1264
 
 
1265
  def ensure_ready():
1266
  if MODEL is None or SCALER is None:
1267
  raise RuntimeError(
@@ -1277,7 +1348,9 @@ def parse_text_features(text: str) -> np.ndarray:
1277
  cleaned = re.sub(r"[;\n\t]+", ",", text.strip())
1278
  arr = np.fromstring(cleaned, sep=",")
1279
  if arr.size == 0:
1280
- raise ValueError("No feature values were parsed. Please enter comma-separated numbers.")
 
 
1281
  return arr.astype(np.float32)
1282
 
1283
 
@@ -1290,13 +1363,18 @@ def apply_scaler(sequences: np.ndarray) -> np.ndarray:
1290
  return scaled.reshape(shape)
1291
 
1292
 
1293
- def make_sliding_windows(data: np.ndarray, sequence_length: int, stride: int) -> np.ndarray:
 
 
1294
  if data.shape[0] < sequence_length:
1295
  raise ValueError(
1296
  f"The dataset contains {data.shape[0]} rows which is less than the requested sequence "
1297
  f"length {sequence_length}. Provide more samples or reduce the sequence length."
1298
  )
1299
- windows = [data[start : start + sequence_length] for start in range(0, data.shape[0] - sequence_length + 1, stride)]
 
 
 
1300
  return np.stack(windows)
1301
 
1302
 
@@ -1363,13 +1441,17 @@ def probabilities_to_json(probabilities: np.ndarray) -> List[Dict[str, object]]:
1363
  payload.append(
1364
  {
1365
  "window": int(idx),
1366
- "probabilities": {label_name(i): float(prob_row[i]) for i in range(prob_row.shape[0])},
 
 
1367
  }
1368
  )
1369
  return payload
1370
 
1371
 
1372
- def predict_sequences(sequences: np.ndarray) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
 
 
1373
  ensure_ready()
1374
  sequences = apply_scaler(sequences.astype(np.float32))
1375
  if MODEL_TYPE == "svm":
@@ -1377,7 +1459,9 @@ def predict_sequences(sequences: np.ndarray) -> Tuple[str, pd.DataFrame, List[Di
1377
  if hasattr(MODEL, "predict_proba"):
1378
  probs = MODEL.predict_proba(flattened)
1379
  else:
1380
- raise RuntimeError("Loaded SVM model does not expose predict_proba. Retrain with probability=True.")
 
 
1381
  else:
1382
  probs = MODEL.predict(sequences, verbose=0)
1383
  table = format_predictions(probs)
@@ -1387,7 +1471,9 @@ def predict_sequences(sequences: np.ndarray) -> Tuple[str, pd.DataFrame, List[Di
1387
  return status, table, json_probs
1388
 
1389
 
1390
- def predict_from_text(text: str, sequence_length: int) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
 
 
1391
  arr = parse_text_features(text)
1392
  n_features = len(FEATURE_COLUMNS)
1393
  if arr.size % n_features != 0:
@@ -1407,7 +1493,9 @@ def predict_from_text(text: str, sequence_length: int) -> Tuple[str, pd.DataFram
1407
  return status, table, probs
1408
 
1409
 
1410
- def predict_from_csv(file_obj, sequence_length: int, stride: int) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
 
 
1411
  df = load_measurement_csv(file_obj.name)
1412
  sequences = dataframe_to_sequences(
1413
  df,
@@ -1441,7 +1529,9 @@ def classification_report_to_dataframe(report: Dict[str, Any]) -> pd.DataFrame:
1441
  return pd.DataFrame(rows)
1442
 
1443
 
1444
- def confusion_matrix_to_dataframe(confusion: Sequence[Sequence[float]], labels: Sequence[str]) -> pd.DataFrame:
 
 
1445
  if not confusion:
1446
  return pd.DataFrame()
1447
  df = pd.DataFrame(confusion, index=list(labels), columns=list(labels))
@@ -1454,8 +1544,11 @@ def confusion_matrix_to_dataframe(confusion: Sequence[Sequence[float]], labels:
1454
  # Gradio interface
1455
  # --------------------------------------------------------------------------------------
1456
 
 
1457
  def build_interface() -> gr.Blocks:
1458
- theme = gr.themes.Soft(primary_hue="sky", secondary_hue="blue", neutral_hue="gray").set(
 
 
1459
  body_background_fill="#1f1f1f",
1460
  body_text_color="#f5f5f5",
1461
  block_background_fill="#262626",
@@ -1477,7 +1570,9 @@ def build_interface() -> gr.Blocks:
1477
  except Exception:
1478
  return str(path)
1479
 
1480
- with gr.Blocks(title="Fault Classification - PMU Data", theme=theme, css=APP_CSS) as demo:
 
 
1481
  gr.Markdown("# Fault Classification for PMU & PV Data")
1482
  gr.Markdown(
1483
  "🖥️ TensorFlow is locked to CPU execution so the Space can run without CUDA drivers."
@@ -1553,7 +1648,11 @@ def build_interface() -> gr.Blocks:
1553
  return predict_from_csv(file_obj, sequence_length, stride)
1554
  if text and text.strip():
1555
  return predict_from_text(text, sequence_length)
1556
- return "Please upload a CSV file or provide feature values.", pd.DataFrame(), []
 
 
 
 
1557
  except Exception as exc:
1558
  return f"Prediction failed: {exc}", pd.DataFrame(), []
1559
 
@@ -1745,7 +1844,11 @@ def build_interface() -> gr.Blocks:
1745
  output_directory.change(
1746
  on_output_directory_change,
1747
  inputs=[output_directory, artifact_browser],
1748
- outputs=[output_directory, artifact_browser, artifact_download_button],
 
 
 
 
1749
  concurrency_limit=EVENT_CONCURRENCY_LIMIT,
1750
  )
1751
 
@@ -1811,13 +1914,19 @@ def build_interface() -> gr.Blocks:
1811
 
1812
  with gr.Row():
1813
  train_button = gr.Button("🛠️ Start Training", variant="primary")
1814
- progress_button = gr.Button("📊 Check Progress", variant="secondary")
 
 
1815
 
1816
  # Training status display
1817
  training_status = gr.Textbox(label="Training Status", interactive=False)
1818
- report_output = gr.Dataframe(label="Classification report", interactive=False)
 
 
1819
  history_output = gr.JSON(label="Training history")
1820
- confusion_output = gr.Dataframe(label="Confusion matrix", interactive=False)
 
 
1821
 
1822
  # Message area at the bottom for progress updates
1823
  with gr.Accordion("📋 Progress Messages", open=True):
@@ -1827,12 +1936,11 @@ def build_interface() -> gr.Blocks:
1827
  max_lines=20,
1828
  interactive=False,
1829
  autoscroll=True,
1830
- placeholder="Click 'Check Progress' to see training updates..."
1831
  )
1832
  with gr.Row():
1833
  gr.Button("🗑️ Clear Messages", variant="secondary").click(
1834
- lambda: "",
1835
- outputs=[progress_messages]
1836
  )
1837
 
1838
  def _run_training(
@@ -1878,7 +1986,7 @@ def build_interface() -> gr.Blocks:
1878
  status_file = model_path.parent / "training_status.txt"
1879
 
1880
  # Initialize status
1881
- with open(status_file, 'w') as f:
1882
  f.write("Starting training setup...")
1883
 
1884
  if not file_paths:
@@ -1887,11 +1995,17 @@ def build_interface() -> gr.Blocks:
1887
  "Use 'Reload dataset from database' and try again."
1888
  )
1889
 
1890
- with open(status_file, 'w') as f:
1891
  f.write("Loading and validating CSV files...")
1892
 
1893
- available_paths = [path for path in file_paths if Path(path).exists()]
1894
- missing_paths = [Path(path).name for path in file_paths if not Path(path).exists()]
 
 
 
 
 
 
1895
  if not available_paths:
1896
  raise ValueError(
1897
  "Database training dataset is unavailable. Reload the dataset and retry."
@@ -1903,24 +2017,38 @@ def build_interface() -> gr.Blocks:
1903
  # Validate data size and provide recommendations
1904
  total_samples = len(combined)
1905
  if total_samples < 100:
1906
- print(f"Warning: Only {total_samples} samples. Recommend at least 1000 for good results.")
1907
- print("Automatically switching to SVM for small dataset compatibility.")
 
 
 
 
1908
  if model_choice in ["cnn_lstm", "tcn"]:
1909
  model_choice = "svm"
1910
- print(f"Model type changed to SVM for better small dataset performance.")
 
 
1911
  if total_samples < 10:
1912
- raise ValueError(f"Insufficient data: {total_samples} samples. Need at least 10 samples for training.")
 
 
1913
 
1914
  label_column = (label_column or LABEL_COLUMN).strip()
1915
  if not label_column:
1916
  raise ValueError("Label column name cannot be empty.")
1917
 
1918
- model_choice = (model_choice or "CNN-LSTM").lower().replace("-", "_")
 
 
1919
  if model_choice not in {"cnn_lstm", "tcn", "svm"}:
1920
- raise ValueError("Select CNN-LSTM, TCN, or SVM for the model architecture.")
 
 
1921
 
1922
- with open(status_file, 'w') as f:
1923
- f.write(f"Starting {model_choice.upper()} training with {len(combined)} samples...")
 
 
1924
 
1925
  # Start training
1926
  result = train_from_dataframe(
@@ -1945,8 +2073,12 @@ def build_interface() -> gr.Blocks:
1945
  Path(result["metadata_path"]),
1946
  )
1947
 
1948
- report_df = classification_report_to_dataframe(result["classification_report"])
1949
- confusion_df = confusion_matrix_to_dataframe(result["confusion_matrix"], result["class_names"])
 
 
 
 
1950
  tensorboard_dir = result.get("tensorboard_log_dir")
1951
  tensorboard_zip = result.get("tensorboard_zip_path")
1952
 
@@ -1966,7 +2098,7 @@ def build_interface() -> gr.Blocks:
1966
  if tensorboard_dir:
1967
  status += (
1968
  f"\nTensorBoard logs directory: {tensorboard_dir}"
1969
- f"\nRun `tensorboard --logdir \"{tensorboard_dir}\"` to inspect the training curves."
1970
  "\nDownload the archive below to explore the run offline."
1971
  )
1972
 
@@ -1988,7 +2120,9 @@ def build_interface() -> gr.Blocks:
1988
  download_button_state(result["metadata_path"]),
1989
  download_button_state(tensorboard_zip),
1990
  gr.update(value=result.get("label_column", label_column)),
1991
- gr.update(choices=artifact_choices, value=selected_artifact),
 
 
1992
  download_button_state(selected_artifact),
1993
  )
1994
  except Exception as exc:
@@ -2005,7 +2139,9 @@ def build_interface() -> gr.Blocks:
2005
  download_button_state(None),
2006
  download_button_state(None),
2007
  gr.update(),
2008
- gr.update(choices=artifact_choices, value=selected_artifact),
 
 
2009
  download_button_state(selected_artifact),
2010
  )
2011
 
@@ -2019,17 +2155,18 @@ def build_interface() -> gr.Blocks:
2019
 
2020
  # Add timestamp to the message
2021
  from datetime import datetime
 
2022
  timestamp = datetime.now().strftime("%H:%M:%S")
2023
  new_message = f"[{timestamp}] {status_message}"
2024
 
2025
  # Accumulate messages, keeping last 50 lines to prevent overflow
2026
  if current_messages:
2027
- lines = current_messages.split('\n')
2028
  lines.append(new_message)
2029
  # Keep only last 50 lines
2030
  if len(lines) > 50:
2031
  lines = lines[-50:]
2032
- accumulated_messages = '\n'.join(lines)
2033
  else:
2034
  accumulated_messages = new_message
2035
 
@@ -2077,7 +2214,12 @@ def build_interface() -> gr.Blocks:
2077
  year_selector.change(
2078
  on_year_change,
2079
  inputs=[year_selector],
2080
- outputs=[month_selector, day_selector, available_files, repo_status],
 
 
 
 
 
2081
  concurrency_limit=EVENT_CONCURRENCY_LIMIT,
2082
  )
2083
 
@@ -2158,7 +2300,9 @@ def build_interface() -> gr.Blocks:
2158
  )
2159
 
2160
  def _reload_dataset(current_label):
2161
- local = load_repository_training_files(current_label, force_refresh=True)
 
 
2162
  remote = refresh_remote_browser(force_refresh=True)
2163
  return (*local, *remote)
2164
 
@@ -2197,7 +2341,9 @@ def build_interface() -> gr.Blocks:
2197
  )
2198
 
2199
  def _initialise_dataset():
2200
- local = load_repository_training_files(LABEL_COLUMN, force_refresh=False)
 
 
2201
  remote = refresh_remote_browser(force_refresh=False)
2202
  return (*local, *remote)
2203
 
@@ -2225,6 +2371,7 @@ def build_interface() -> gr.Blocks:
2225
  # Launch helpers
2226
  # --------------------------------------------------------------------------------------
2227
 
 
2228
  def resolve_server_port() -> int:
2229
  for env_var in ("PORT", "GRADIO_SERVER_PORT"):
2230
  value = os.environ.get(env_var)
@@ -2244,6 +2391,7 @@ def main():
2244
  except Exception as e:
2245
  print(f"Failed to build interface: {e}")
2246
  import traceback
 
2247
  traceback.print_exc()
2248
  return
2249
 
@@ -2267,16 +2415,17 @@ def main():
2267
  except Exception as e:
2268
  print(f"Unexpected launch error: {e}")
2269
  import traceback
 
2270
  traceback.print_exc()
2271
 
2272
 
2273
  if __name__ == "__main__":
2274
- print("="*50)
2275
  print("PMU Fault Classification App Starting")
2276
  print(f"Python version: {os.sys.version}")
2277
  print(f"Working directory: {os.getcwd()}")
2278
  print(f"HUB_REPO: {HUB_REPO}")
2279
  print(f"Model available: {MODEL is not None}")
2280
  print(f"Scaler available: {SCALER is not None}")
2281
- print("="*50)
2282
  main()
 
6
  raw PMU time-series CSV uploads as well as manual comma separated feature
7
  vectors.
8
  """
9
+
10
  from __future__ import annotations
11
 
12
  import json
 
62
  # Utility functions for loading artifacts
63
  # --------------------------------------------------------------------------------------
64
 
65
+
66
  def download_from_hub(filename: str) -> Optional[Path]:
67
  if not HUB_REPO or not filename:
68
  return None
 
78
  return None
79
 
80
 
81
+ def resolve_artifact(
82
+ local_name: str, env_var: str, hub_filename: str
83
+ ) -> Optional[Path]:
84
  print(f"Resolving artifact: {local_name}, env: {env_var}, hub: {hub_filename}")
85
  candidates = [Path(local_name)] if local_name else []
86
  if local_name:
 
134
  MODEL_PATH = None
135
 
136
  try:
137
+ SCALER_PATH = resolve_artifact(
138
+ LOCAL_SCALER_FILE, ENV_SCALER_PATH, HUB_SCALER_FILENAME
139
+ )
140
  print(f"Scaler path resolved: {SCALER_PATH}")
141
  except Exception as e:
142
  print(f"Scaler path resolution failed: {e}")
143
  SCALER_PATH = None
144
 
145
  try:
146
+ METADATA_PATH = resolve_artifact(
147
+ LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME
148
+ )
149
  print(f"Metadata path resolved: {METADATA_PATH}")
150
  except Exception as e:
151
  print(f"Metadata path resolution failed: {e}")
 
164
  # ``concurrency_count`` parameter when enabling Gradio's request queue.
165
  EVENT_CONCURRENCY_LIMIT = 2
166
 
167
+
168
  def try_load_model(path: Optional[Path], model_type: str, model_format: str):
169
  if not path:
170
  return None
 
188
  MODEL_TYPE: str = "cnn_lstm"
189
  MODEL_FORMAT: str = "keras"
190
 
191
+
192
  def _model_output_path(filename: str) -> str:
193
  return str(MODEL_OUTPUT_DIR / Path(filename).name)
194
 
 
200
  }
201
 
202
  REQUIRED_PMU_COLUMNS: Tuple[str, ...] = tuple(DEFAULT_FEATURE_COLUMNS)
203
+ TRAINING_UPLOAD_DIR = Path(
204
+ os.environ.get("PMU_TRAINING_UPLOAD_DIR", "training_uploads")
205
+ )
206
  TRAINING_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
207
 
208
+ TRAINING_DATA_REPO = os.environ.get(
209
+ "PMU_TRAINING_DATA_REPO", "VincentCroft/ThesisModelData"
210
+ )
211
  TRAINING_DATA_BRANCH = os.environ.get("PMU_TRAINING_DATA_BRANCH", "main")
212
  TRAINING_DATA_DIR = Path(os.environ.get("PMU_TRAINING_DATA_DIR", "training_dataset"))
213
  TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
 
222
  flex-direction: column;
223
  gap: 0.75rem;
224
  border-radius: 0.75rem;
225
+ isolation: isolate;
226
  }
227
 
228
  #available-files-grid {
 
235
  min-height: 16rem;
236
  }
237
 
 
 
 
 
238
  #available-files-grid .wrap {
239
  display: grid;
240
  grid-template-columns: repeat(4, minmax(0, 1fr));
 
272
  white-space: nowrap;
273
  }
274
 
294
+ #available-files-section .gradio-loading,
295
  #available-files-grid .gradio-loading {
296
  position: absolute;
297
+ top: 0;
298
+ left: 0;
299
+ right: 0;
300
+ bottom: 0;
301
+ width: 100%;
302
+ height: 100%;
303
  display: flex;
304
  align-items: center;
305
  justify-content: center;
306
+ background: rgba(10, 14, 23, 0.92);
307
  border-radius: 0.75rem;
308
+ z-index: 999;
309
  padding: 1.5rem;
310
  pointer-events: auto;
311
  }
312
 
313
+ #available-files-section .gradio-loading > *,
314
  #available-files-grid .gradio-loading > * {
315
  width: 100%;
316
  }
317
 
318
+ #available-files-section .gradio-loading progress,
319
+ #available-files-section .gradio-loading .progress-bar,
320
+ #available-files-section .gradio-loading .loading-progress,
321
+ #available-files-section .gradio-loading [role="progressbar"],
322
+ #available-files-section .gradio-loading .wrap,
323
+ #available-files-section .gradio-loading .inner,
324
  #available-files-grid .gradio-loading progress,
325
  #available-files-grid .gradio-loading .progress-bar,
326
  #available-files-grid .gradio-loading .loading-progress,
 
331
  max-width: none !important;
332
  }
333
 
334
+ #available-files-section .gradio-loading .status,
335
+ #available-files-section .gradio-loading .message,
336
+ #available-files-section .gradio-loading .label,
337
  #available-files-grid .gradio-loading .status,
338
  #available-files-grid .gradio-loading .message,
339
  #available-files-grid .gradio-loading .label {
 
406
  return f"{base}?ref={TRAINING_DATA_BRANCH}"
407
 
408
 
409
+ def list_remote_directory(
410
+ path: str = "", *, force_refresh: bool = False
411
+ ) -> List[Dict[str, Any]]:
412
  key = _github_cache_key(path)
413
  if not force_refresh and key in GITHUB_CONTENT_CACHE:
414
  return GITHUB_CONTENT_CACHE[key]
 
422
 
423
  payload = response.json()
424
  if not isinstance(payload, list):
425
+ raise RuntimeError(
426
+ "Unexpected GitHub API payload. Expected a directory listing."
427
+ )
428
 
429
  GITHUB_CONTENT_CACHE[key] = payload
430
  return payload
 
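Usage sketch for the directory listing above (illustrative, not part of the commit; requires network access, and this file is assumed importable as app; GitHub's contents API returns entries with "name" and "type" fields):

from app import list_remote_directory

entries = list_remote_directory("")   # repository root, e.g. the year folders
years = sorted(e["name"] for e in entries if e.get("type") == "dir")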
444
  return sorted(months)
445
 
446
 
447
+ def list_remote_days(
448
+ year: str, month: str, *, force_refresh: bool = False
449
+ ) -> List[str]:
450
  if not year or not month:
451
  return []
452
  entries = list_remote_directory(f"{year}/{month}", force_refresh=force_refresh)
 
454
  return sorted(days)
455
 
456
 
457
+ def list_remote_files(
458
+ year: str, month: str, day: str, *, force_refresh: bool = False
459
+ ) -> List[str]:
460
  if not year or not month or not day:
461
  return []
462
  entries = list_remote_directory(
 
502
  return str(name).strip().lower()
503
 
504
 
505
+ def guess_label_from_columns(
506
+ columns: Sequence[str], preferred: Optional[str] = None
507
+ ) -> Optional[str]:
508
  if not columns:
509
  return preferred
510
 
 
541
  """Read the current training status from file."""
542
  try:
543
  if Path(status_file_path).exists():
544
+ with open(status_file_path, "r") as f:
545
  return f.read().strip()
546
  except Exception:
547
  pass
 
598
 
599
  summary = summarise_training_files(valid_paths, notes)
600
  preferred = current_label or LABEL_COLUMN
601
+ dropdown_choices = (
602
+ sorted(columns_map.values()) if columns_map else [preferred or LABEL_COLUMN]
603
+ )
604
  guessed = guess_label_from_columns(dropdown_choices, preferred)
605
  dropdown_value = guessed or preferred or LABEL_COLUMN
606
 
607
+ return (
608
+ valid_paths,
609
+ summary,
610
+ gr.update(choices=dropdown_choices, value=dropdown_value),
611
+ )
612
 
613
 
614
  def append_training_files(new_files, existing_paths: Sequence[str], current_label: str):
 
639
  break
640
 
641
  csv_paths = sorted(
642
+ str(path) for path in TRAINING_DATA_DIR.rglob("*.csv") if path.is_file()
 
  )
644
  if not csv_paths:
645
  message = (
 
792
  notes: List[str] = []
793
  for filename in filenames:
794
  try:
795
+ path = download_repository_file(
796
+ year or "", month or "", day or "", filename
797
+ )
798
  success.append(str(path))
799
  except Exception as exc:
800
  notes.append(f"⚠️ {filename}: {exc}")
 
899
  download_repository_file(year, month, day, filename)
900
  downloaded += 1
901
  except Exception as exc:
902
+ notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
 
 
903
 
904
  local = load_repository_training_files(current_label)
905
  message_lines = []
 
967
  download_repository_file(year, month, day, filename)
968
  downloaded += 1
969
  except Exception as exc:
970
+ notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
 
 
971
 
972
  local = load_repository_training_files(current_label)
973
  message_lines = []
 
1110
  gr.update(value=None),
1111
  )
1112
 
1113
+
1114
  PROJECT_OVERVIEW_MD = """
1115
  ## Project Overview
1116
 
 
1184
  df = None
1185
  for separator in ("\t", ",", ";"):
1186
  try:
1187
+ df = pd.read_csv(
1188
+ path, sep=separator, engine="python", encoding="utf-8-sig"
1189
+ )
1190
  break
1191
  except Exception:
1192
  df = None
 

     # Check if we have enough data for training
     if len(df) < 100:
+        print(
+            f"Warning: Only {len(df)} rows of data. Recommend at least 1000 rows for effective training."
+        )

     # Check for label column
+    has_label = any(
+        col.lower() in ["fault", "label", "class", "target"] for col in df.columns
+    )
     if not has_label:
+        print(
+            "Warning: No label column found. Adding dummy 'Fault' column with value 'Normal' for all samples."
+        )
+        df["Fault"] = "Normal"  # Add dummy label for training

     # Create column mapping - map similar column names to expected format
     column_mapping = {}
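The label check is case-insensitive and falls back to a dummy 'Fault' column so unlabeled uploads can still flow through the pipeline. The same detection in isolation (the sample frame is invented):

import pandas as pd

df = pd.DataFrame({"Va": [1.0, 2.0], "Vb": [3.0, 4.0]})  # no label column

LABEL_ALIASES = {"fault", "label", "class", "target"}
has_label = any(col.lower() in LABEL_ALIASES for col in df.columns)
if not has_label:
    df["Fault"] = "Normal"  # dummy label keeps downstream training code working

print(has_label, df.columns.tolist())  # -> False ['Va', 'Vb', 'Fault']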
 

 def apply_metadata(metadata: Dict[str, Any]) -> None:
     global FEATURE_COLUMNS, LABEL_CLASSES, LABEL_COLUMN, SEQUENCE_LENGTH, DEFAULT_WINDOW_STRIDE, MODEL_TYPE, MODEL_FORMAT
+    FEATURE_COLUMNS = [
+        str(col) for col in metadata.get("feature_columns", DEFAULT_FEATURE_COLUMNS)
+    ]
     LABEL_CLASSES = [str(label) for label in metadata.get("label_classes", [])]
     LABEL_COLUMN = str(metadata.get("label_column", "Fault"))
     SEQUENCE_LENGTH = int(metadata.get("sequence_length", DEFAULT_SEQUENCE_LENGTH))
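apply_metadata rehydrates module-level configuration from the training metadata JSON, with dict.get defaults so partial metadata still loads. A trimmed, self-contained version of the same idea (the config dict and default values here are illustrative, not the app's actual globals):

import json
from typing import Any, Dict

DEFAULT_FEATURE_COLUMNS = ["Va", "Vb", "Vc"]
DEFAULT_SEQUENCE_LENGTH = 60

config: Dict[str, Any] = {}

def apply_metadata(metadata: Dict[str, Any]) -> None:
    # Missing keys fall back to defaults instead of raising KeyError.
    config["feature_columns"] = [
        str(c) for c in metadata.get("feature_columns", DEFAULT_FEATURE_COLUMNS)
    ]
    config["label_column"] = str(metadata.get("label_column", "Fault"))
    config["sequence_length"] = int(metadata.get("sequence_length", DEFAULT_SEQUENCE_LENGTH))

apply_metadata(json.loads('{"label_column": "Fault", "sequence_length": 32}'))
print(config)  # feature_columns fall back to the defaults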
 

 apply_metadata(METADATA)

+
 def sync_label_classes_from_model(model: Optional[object]) -> None:
     global LABEL_CLASSES
     if model is None:
 
         print(f"Label sync failed: {e}")

 print("Application initialization completed.")
+print(
+    f"Ready to start Gradio interface. Model available: {MODEL is not None}, Scaler available: {SCALER is not None}"
+)


 def refresh_artifacts(model_path: Path, scaler_path: Path, metadata_path: Path) -> None:
 
     SCALER = try_load_scaler(scaler_path)
     sync_label_classes_from_model(MODEL)

+
 # --------------------------------------------------------------------------------------
 # Pre-processing helpers
 # --------------------------------------------------------------------------------------

+
 def ensure_ready():
     if MODEL is None or SCALER is None:
         raise RuntimeError(
 
     cleaned = re.sub(r"[;\n\t]+", ",", text.strip())
     arr = np.fromstring(cleaned, sep=",")
     if arr.size == 0:
+        raise ValueError(
+            "No feature values were parsed. Please enter comma-separated numbers."
+        )
     return arr.astype(np.float32)
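Note that np.fromstring(..., sep=",") has been deprecated since NumPy 1.14 (it still works but emits a DeprecationWarning). A forward-compatible sketch of the same text-to-vector parse, reimplementing the helper with a plain split:

import re
import numpy as np

def parse_text_features(text: str) -> np.ndarray:
    # Normalise semicolons/newlines/tabs to commas, then parse each token.
    cleaned = re.sub(r"[;\n\t]+", ",", text.strip())
    tokens = [tok for tok in cleaned.split(",") if tok.strip()]
    if not tokens:
        raise ValueError("No feature values were parsed. Please enter comma-separated numbers.")
    return np.array([float(tok) for tok in tokens], dtype=np.float32)

print(parse_text_features("1.0; 2.5\n3.75"))  # -> [1.   2.5  3.75]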


     return scaled.reshape(shape)


+def make_sliding_windows(
+    data: np.ndarray, sequence_length: int, stride: int
+) -> np.ndarray:
     if data.shape[0] < sequence_length:
         raise ValueError(
             f"The dataset contains {data.shape[0]} rows which is less than the requested sequence "
             f"length {sequence_length}. Provide more samples or reduce the sequence length."
         )
+    windows = [
+        data[start : start + sequence_length]
+        for start in range(0, data.shape[0] - sequence_length + 1, stride)
+    ]
     return np.stack(windows)
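make_sliding_windows turns a (rows, features) array into an overlapping (windows, sequence_length, features) block; with stride smaller than sequence_length the windows overlap. A quick shape check (the function is restated with a shortened error message so the snippet runs standalone):

import numpy as np

def make_sliding_windows(data: np.ndarray, sequence_length: int, stride: int) -> np.ndarray:
    if data.shape[0] < sequence_length:
        raise ValueError("Not enough rows for one window.")
    windows = [
        data[start : start + sequence_length]
        for start in range(0, data.shape[0] - sequence_length + 1, stride)
    ]
    return np.stack(windows)

data = np.arange(20, dtype=np.float32).reshape(10, 2)  # 10 rows, 2 features
out = make_sliding_windows(data, sequence_length=4, stride=2)
print(out.shape)  # -> (4, 4, 2): windows start at rows 0, 2, 4, 6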


     payload.append(
         {
             "window": int(idx),
+            "probabilities": {
+                label_name(i): float(prob_row[i]) for i in range(prob_row.shape[0])
+            },
         }
     )
     return payload


+def predict_sequences(
+    sequences: np.ndarray,
+) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
     ensure_ready()
     sequences = apply_scaler(sequences.astype(np.float32))
     if MODEL_TYPE == "svm":

         if hasattr(MODEL, "predict_proba"):
             probs = MODEL.predict_proba(flattened)
         else:
+            raise RuntimeError(
+                "Loaded SVM model does not expose predict_proba. Retrain with probability=True."
+            )
     else:
         probs = MODEL.predict(sequences, verbose=0)
     table = format_predictions(probs)
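The SVM branch needs predict_proba, which scikit-learn's SVC only provides when trained with probability=True; it also flattens each window because SVMs expect 2-D input. A minimal sketch of both points on synthetic data:

import numpy as np
from sklearn.svm import SVC

rng = np.random.default_rng(0)
X = rng.normal(size=(40, 8, 3))  # 40 windows, 8 timesteps, 3 features
y = np.array([0, 1] * 20)        # two synthetic classes

flattened = X.reshape(X.shape[0], -1)            # SVMs expect 2-D input: (samples, 8*3)
model = SVC(probability=True).fit(flattened, y)  # probability=True enables predict_proba

print(model.predict_proba(flattened[:2]).shape)  # -> (2, 2): per-window class probabilities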
 
     return status, table, json_probs


+def predict_from_text(
+    text: str, sequence_length: int
+) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
     arr = parse_text_features(text)
     n_features = len(FEATURE_COLUMNS)
     if arr.size % n_features != 0:

     return status, table, probs


+def predict_from_csv(
+    file_obj, sequence_length: int, stride: int
+) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
     df = load_measurement_csv(file_obj.name)
     sequences = dataframe_to_sequences(
         df,

     return pd.DataFrame(rows)


+def confusion_matrix_to_dataframe(
+    confusion: Sequence[Sequence[float]], labels: Sequence[str]
+) -> pd.DataFrame:
     if not confusion:
         return pd.DataFrame()
     df = pd.DataFrame(confusion, index=list(labels), columns=list(labels))
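confusion_matrix_to_dataframe simply labels the raw matrix with class names on both axes, so true classes read down the rows and predictions across the columns. For example:

import pandas as pd

confusion = [[12, 3], [1, 9]]  # rows: true class, columns: predicted class
labels = ["Normal", "Fault"]

df = pd.DataFrame(confusion, index=list(labels), columns=list(labels))
print(df)
#         Normal  Fault
# Normal      12      3
# Fault        1      9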
 
 # Gradio interface
 # --------------------------------------------------------------------------------------

+
 def build_interface() -> gr.Blocks:
+    theme = gr.themes.Soft(
+        primary_hue="sky", secondary_hue="blue", neutral_hue="gray"
+    ).set(
         body_background_fill="#1f1f1f",
         body_text_color="#f5f5f5",
         block_background_fill="#262626",

     except Exception:
         return str(path)

+    with gr.Blocks(
+        title="Fault Classification - PMU Data", theme=theme, css=APP_CSS
+    ) as demo:
         gr.Markdown("# Fault Classification for PMU & PV Data")
         gr.Markdown(
             "🖥️ TensorFlow is locked to CPU execution so the Space can run without CUDA drivers."
 
             return predict_from_csv(file_obj, sequence_length, stride)
         if text and text.strip():
             return predict_from_text(text, sequence_length)
+        return (
+            "Please upload a CSV file or provide feature values.",
+            pd.DataFrame(),
+            [],
+        )
     except Exception as exc:
         return f"Prediction failed: {exc}", pd.DataFrame(), []


     output_directory.change(
         on_output_directory_change,
         inputs=[output_directory, artifact_browser],
+        outputs=[
+            output_directory,
+            artifact_browser,
+            artifact_download_button,
+        ],
         concurrency_limit=EVENT_CONCURRENCY_LIMIT,
     )
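Each control is wired with .change(fn, inputs, outputs), so editing one component cascades updates into the dependent ones. A runnable miniature of the same wiring, assuming a recent Gradio release where gr.update is available (the choices are invented, not the app's real data):

import gradio as gr

def on_year_change(year):
    # Hypothetical cascade: picking a year repopulates the month dropdown.
    months = ["01", "02"] if year == "2024" else ["11", "12"]
    return gr.update(choices=months, value=months[0])

with gr.Blocks() as demo:
    year = gr.Dropdown(choices=["2023", "2024"], label="Year")
    month = gr.Dropdown(choices=[], label="Month")
    year.change(on_year_change, inputs=[year], outputs=[month])

# demo.launch()  # uncomment to try it locally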


     with gr.Row():
         train_button = gr.Button("🛠️ Start Training", variant="primary")
+        progress_button = gr.Button(
+            "📊 Check Progress", variant="secondary"
+        )

     # Training status display
     training_status = gr.Textbox(label="Training Status", interactive=False)
+    report_output = gr.Dataframe(
+        label="Classification report", interactive=False
+    )
     history_output = gr.JSON(label="Training history")
+    confusion_output = gr.Dataframe(
+        label="Confusion matrix", interactive=False
+    )

     # Message area at the bottom for progress updates
     with gr.Accordion("📋 Progress Messages", open=True):
 
             max_lines=20,
             interactive=False,
             autoscroll=True,
+            placeholder="Click 'Check Progress' to see training updates...",
         )
         with gr.Row():
             gr.Button("🗑️ Clear Messages", variant="secondary").click(
+                lambda: "", outputs=[progress_messages]
             )

         def _run_training(
 
             status_file = model_path.parent / "training_status.txt"

             # Initialize status
+            with open(status_file, "w") as f:
                 f.write("Starting training setup...")

             if not file_paths:
 
                     "Use 'Reload dataset from database' and try again."
                 )

+            with open(status_file, "w") as f:
                 f.write("Loading and validating CSV files...")

+            available_paths = [
+                path for path in file_paths if Path(path).exists()
+            ]
+            missing_paths = [
+                Path(path).name
+                for path in file_paths
+                if not Path(path).exists()
+            ]
             if not available_paths:
                 raise ValueError(
                     "Database training dataset is unavailable. Reload the dataset and retry."
 
             # Validate data size and provide recommendations
             total_samples = len(combined)
             if total_samples < 100:
+                print(
+                    f"Warning: Only {total_samples} samples. Recommend at least 1000 for good results."
+                )
+                print(
+                    "Automatically switching to SVM for small dataset compatibility."
+                )
                 if model_choice in ["cnn_lstm", "tcn"]:
                     model_choice = "svm"
+                    print(
+                        "Model type changed to SVM for better small dataset performance."
+                    )
             if total_samples < 10:
+                raise ValueError(
+                    f"Insufficient data: {total_samples} samples. Need at least 10 samples for training."
+                )

             label_column = (label_column or LABEL_COLUMN).strip()
             if not label_column:
                 raise ValueError("Label column name cannot be empty.")

+            model_choice = (
+                (model_choice or "CNN-LSTM").lower().replace("-", "_")
+            )
             if model_choice not in {"cnn_lstm", "tcn", "svm"}:
+                raise ValueError(
+                    "Select CNN-LSTM, TCN, or SVM for the model architecture."
+                )
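The UI label (e.g. "CNN-LSTM") is normalised to an internal key ("cnn_lstm") by lower-casing and swapping hyphens for underscores before validation, with None falling back to the default. The same normalisation as a standalone helper (the function name is hypothetical):

from typing import Optional

def normalise_model_choice(model_choice: Optional[str]) -> str:
    # "CNN-LSTM" -> "cnn_lstm", "TCN" -> "tcn", "SVM" -> "svm"
    choice = (model_choice or "CNN-LSTM").lower().replace("-", "_")
    if choice not in {"cnn_lstm", "tcn", "svm"}:
        raise ValueError("Select CNN-LSTM, TCN, or SVM for the model architecture.")
    return choice

print(normalise_model_choice("CNN-LSTM"))  # -> cnn_lstm
print(normalise_model_choice(None))        # default -> cnn_lstm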

+            with open(status_file, "w") as f:
+                f.write(
+                    f"Starting {model_choice.upper()} training with {len(combined)} samples..."
+                )

             # Start training
             result = train_from_dataframe(
 
                 Path(result["metadata_path"]),
             )

+            report_df = classification_report_to_dataframe(
+                result["classification_report"]
+            )
+            confusion_df = confusion_matrix_to_dataframe(
+                result["confusion_matrix"], result["class_names"]
+            )
             tensorboard_dir = result.get("tensorboard_log_dir")
             tensorboard_zip = result.get("tensorboard_zip_path")

             if tensorboard_dir:
                 status += (
                     f"\nTensorBoard logs directory: {tensorboard_dir}"
+                    f'\nRun `tensorboard --logdir "{tensorboard_dir}"` to inspect the training curves.'
                     "\nDownload the archive below to explore the run offline."
                 )

                 download_button_state(result["metadata_path"]),
                 download_button_state(tensorboard_zip),
                 gr.update(value=result.get("label_column", label_column)),
+                gr.update(
+                    choices=artifact_choices, value=selected_artifact
+                ),
                 download_button_state(selected_artifact),
             )
         except Exception as exc:
 
                 download_button_state(None),
                 download_button_state(None),
                 gr.update(),
+                gr.update(
+                    choices=artifact_choices, value=selected_artifact
+                ),
                 download_button_state(selected_artifact),
             )


             # Add timestamp to the message
             from datetime import datetime
+
             timestamp = datetime.now().strftime("%H:%M:%S")
             new_message = f"[{timestamp}] {status_message}"

             # Accumulate messages, keeping last 50 lines to prevent overflow
             if current_messages:
+                lines = current_messages.split("\n")
                 lines.append(new_message)
                 # Keep only last 50 lines
                 if len(lines) > 50:
                     lines = lines[-50:]
+                accumulated_messages = "\n".join(lines)
             else:
                 accumulated_messages = new_message
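Progress messages are kept as one newline-joined string capped at 50 lines, so the textbox never grows without bound. The same rolling-log logic as a pure function (append_message is a hypothetical name for this sketch):

from datetime import datetime

def append_message(current_messages: str, status_message: str, max_lines: int = 50) -> str:
    # Prefix each entry with a wall-clock timestamp.
    new_message = f"[{datetime.now().strftime('%H:%M:%S')}] {status_message}"
    if not current_messages:
        return new_message
    lines = current_messages.split("\n")
    lines.append(new_message)
    return "\n".join(lines[-max_lines:])  # keep only the newest max_lines entries

log = ""
for step in ("setup", "training", "done"):
    log = append_message(log, step)
print(log)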


     year_selector.change(
         on_year_change,
         inputs=[year_selector],
+        outputs=[
+            month_selector,
+            day_selector,
+            available_files,
+            repo_status,
+        ],
         concurrency_limit=EVENT_CONCURRENCY_LIMIT,
     )


     )

     def _reload_dataset(current_label):
+        local = load_repository_training_files(
+            current_label, force_refresh=True
+        )
         remote = refresh_remote_browser(force_refresh=True)
         return (*local, *remote)


     )

     def _initialise_dataset():
+        local = load_repository_training_files(
+            LABEL_COLUMN, force_refresh=False
+        )
         remote = refresh_remote_browser(force_refresh=False)
         return (*local, *remote)


 # Launch helpers
 # --------------------------------------------------------------------------------------

+
 def resolve_server_port() -> int:
     for env_var in ("PORT", "GRADIO_SERVER_PORT"):
         value = os.environ.get(env_var)
 
2391
  except Exception as e:
2392
  print(f"Failed to build interface: {e}")
2393
  import traceback
2394
+
2395
  traceback.print_exc()
2396
  return
2397
 
 
     except Exception as e:
         print(f"Unexpected launch error: {e}")
         import traceback
+
         traceback.print_exc()


 if __name__ == "__main__":
+    print("=" * 50)
     print("PMU Fault Classification App Starting")
     print(f"Python version: {os.sys.version}")
     print(f"Working directory: {os.getcwd()}")
     print(f"HUB_REPO: {HUB_REPO}")
     print(f"Model available: {MODEL is not None}")
     print(f"Scaler available: {SCALER is not None}")
+    print("=" * 50)
     main()