Files changed (1) hide show
  1. main.py +410 -405
main.py CHANGED
@@ -1,406 +1,411 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- import joblib
5
- import random
6
- import os
7
- from fastapi import FastAPI
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from pydantic import BaseModel
10
- from contextlib import asynccontextmanager
11
-
12
- # ==========================================
13
- # 1. CORE COMPONENTS (SYNTAX-VALIDATED)
14
- # ==========================================
15
- class Mish(nn.Module):
16
- def forward(self, x):
17
- return x * torch.tanh(nn.functional.softplus(x))
18
-
19
- class FourierFeatureMapping(nn.Module):
20
- def __init__(self, input_dim, mapping_size, scale=10.0):
21
- super().__init__()
22
- self.register_buffer('B', torch.randn(input_dim, mapping_size) * scale)
23
-
24
- def forward(self, x):
25
- proj = 2 * np.pi * (x @ self.B)
26
- return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
27
-
28
- # ==========================================
29
- # 2. AUDIT-COMPLIANT ARCHITECTURES (EXACT TENSOR MATCH)
30
- # ==========================================
31
- class SolarPINN(nn.Module):
32
- """Matches audit: backbone.0/2 + output_layer + physics params (shape [])"""
33
- def __init__(self):
34
- super().__init__()
35
- self.backbone = nn.Sequential(
36
- nn.Linear(4, 128), Mish(),
37
- nn.Linear(128, 128), Mish()
38
- )
39
- self.output_layer = nn.Linear(128, 1)
40
- # Physics parameters required by state_dict (shape [])
41
- self.log_thermal_mass = nn.Parameter(torch.tensor(0.0))
42
- self.log_h_conv = nn.Parameter(torch.tensor(0.0))
43
-
44
- def forward(self, x):
45
- return self.output_layer(self.backbone(x))
46
-
47
- class LoadForecastPINN(nn.Module):
48
- """Matches audit: res_blocks with LayerNorm weights at .1 (shape [128])"""
49
- def __init__(self):
50
- super().__init__()
51
- self.fourier = FourierFeatureMapping(9, 32)
52
- self.input_layer = nn.Linear(64, 128)
53
- self.res_blocks = nn.ModuleList([
54
- nn.Sequential(
55
- nn.Linear(128, 128),
56
- nn.LayerNorm(128), # Critical: Audit shows LayerNorm params
57
- Mish(),
58
- nn.Linear(128, 128)
59
- ) for _ in range(3)
60
- ])
61
- self.output_layer = nn.Linear(128, 1)
62
-
63
- def forward(self, x):
64
- x = self.input_layer(self.fourier(x))
65
- for block in self.res_blocks:
66
- x = x + block(x) # True residual connection per audit
67
- return self.output_layer(x)
68
-
69
- class VoltagePINN(nn.Module):
70
- """Matches audit: network layers + v_bias([1]) + raw_B([])"""
71
- def __init__(self):
72
- super().__init__()
73
- self.fourier = FourierFeatureMapping(7, 32)
74
- self.network = nn.Sequential(
75
- nn.Linear(64, 256), nn.LayerNorm(256), Mish(),
76
- nn.Linear(256, 128), nn.LayerNorm(128), Mish(),
77
- nn.Linear(128, 64), nn.LayerNorm(64), Mish(),
78
- nn.Linear(64, 2)
79
- )
80
- # Audit-required parameters
81
- self.v_bias = nn.Parameter(torch.zeros(1)) # Shape [1]
82
- self.raw_B = nn.Parameter(torch.tensor(0.0)) # Shape []
83
-
84
- def forward(self, x):
85
- return self.network(self.fourier(x))
86
-
87
- class BatteryPINN(nn.Module):
88
- """Matches audit: network.0/2/4 indexing"""
89
- def __init__(self):
90
- super().__init__()
91
- self.fourier = FourierFeatureMapping(5, 12)
92
- self.network = nn.Sequential(
93
- nn.Linear(24, 64), Mish(),
94
- nn.Linear(64, 64), Mish(),
95
- nn.Linear(64, 3)
96
- )
97
-
98
- def forward(self, x):
99
- return self.network(self.fourier(x))
100
- class FrequencyPINN(nn.Module):
101
- """Matches audit: net.0/2/4/6 (NO LayerNorm - pure Linear+Mish)"""
102
- def __init__(self):
103
- super().__init__()
104
- self.fourier = FourierFeatureMapping(4, 32)
105
- self.net = nn.Sequential(
106
- nn.Linear(64, 128), Mish(), # net.0
107
- nn.Linear(128, 128), Mish(), # net.2
108
- nn.Linear(128, 128), Mish(), # net.4
109
- nn.Linear(128, 2) # net.6
110
- )
111
-
112
- def forward(self, x):
113
- return self.net(self.fourier(x))
114
-
115
- # ==========================================
116
- # 3. LIFESPAN: ORIGINAL KEYS + SCALER SAFETY
117
- # ==========================================
118
- ml_assets = {}
119
-
120
- @asynccontextmanager
121
- async def lifespan(app: FastAPI):
122
- try:
123
- # SOLAR MODEL (Key: "solar_model" per initial code)
124
- if os.path.exists("solar_model.pt"):
125
- ckpt = torch.load("solar_model.pt", map_location='cpu')
126
- sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
127
- model = SolarPINN()
128
- model.load_state_dict(sd, strict=True)
129
- ml_assets["solar_model"] = model.eval()
130
- ml_assets["solar_stats"] = {
131
- "irr_mean": 450.0, "irr_std": 250.0,
132
- "temp_mean": 25.0, "temp_std": 10.0,
133
- "prev_mean": 35.0, "prev_std": 15.0
134
- }
135
-
136
- # LOAD MODEL (Key: "l_model")
137
- if os.path.exists("load_model.pt"):
138
- ckpt = torch.load("load_model.pt", map_location='cpu')
139
- sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
140
- model = LoadForecastPINN()
141
- model.load_state_dict(sd, strict=True)
142
- ml_assets["l_model"] = model.eval()
143
- if os.path.exists("Load_stats.joblib"):
144
- ml_assets["l_stats"] = joblib.load("Load_stats.joblib")
145
-
146
- # VOLTAGE MODEL (Key: "v_model")
147
- if os.path.exists("voltage_model_v3.pt"):
148
- ckpt = torch.load("voltage_model_v3.pt", map_location='cpu')
149
- sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt model = VoltagePINN()
150
- model.load_state_dict(sd, strict=True)
151
- ml_assets["v_model"] = model.eval()
152
- if os.path.exists("scaling_stats_v3.joblib"):
153
- ml_assets["v_stats"] = joblib.load("scaling_stats_v3.joblib")
154
-
155
- # BATTERY MODEL (Key: "b_model")
156
- if os.path.exists("battery_model.pt"):
157
- ckpt = torch.load("battery_model.pt", map_location='cpu')
158
- sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
159
- model = BatteryPINN()
160
- model.load_state_dict(sd, strict=True)
161
- ml_assets["b_model"] = model.eval()
162
- if os.path.exists("battery_model.joblib"):
163
- ml_assets["b_stats"] = joblib.load("battery_model.joblib")
164
-
165
- # FREQUENCY MODEL (Key: "f_model" + SCALER SAFETY)
166
- if os.path.exists("DECODE_Frequency_Twin.pth"):
167
- ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
168
- sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
169
- model = FrequencyPINN()
170
- model.load_state_dict(sd, strict=True)
171
- ml_assets["f_model"] = model.eval()
172
- # CRITICAL: Load actual MinMaxScaler per audit metadata
173
- if os.path.exists("decode_scaler.joblib"):
174
- try:
175
- ml_assets["f_scaler"] = joblib.load("decode_scaler.joblib")
176
- except:
177
- ml_assets["f_scaler"] = None
178
- else:
179
- ml_assets["f_scaler"] = None
180
-
181
- yield
182
- finally:
183
- ml_assets.clear()
184
-
185
- # ==========================================
186
- # 4. FASTAPI SETUP
187
- # ==========================================
188
- app = FastAPI(title="D.E.C.O.D.E. Unified Digital Twin", lifespan=lifespan)
189
- app.add_middleware(
190
- CORSMiddleware,
191
- allow_origins=["*"],
192
- allow_methods=["*"],
193
- allow_headers=["*"],
194
- )
195
-
196
- # ==========================================
197
- # 5. PHYSICS & SCHEMAS (SYNTAX-CORRECTED)
198
- # ==========================================def get_ocv_soc(voltage: float) -> float:
199
- """Physics-based SOC estimation from OCV"""
200
- return np.interp(voltage, [2.8, 3.4, 3.7, 4.2], [0, 15, 65, 100])
201
-
202
- class SolarData(BaseModel):
203
- irradiance_stream: list[float]
204
- ambient_temp_stream: list[float]
205
- wind_speed_stream: list[float]
206
-
207
- class LoadData(BaseModel): # FIXED: Each field on separate line
208
- temperature_c: float
209
- hour: int # Critical newline separation
210
- month: int # Critical newline separation
211
- wind_mw: float = 0.0
212
- solar_mw: float = 0.0
213
-
214
- class BatteryData(BaseModel):
215
- time_sec: float
216
- current: float
217
- voltage: float
218
- temperature: float
219
- soc_prev: float
220
-
221
- class FreqData(BaseModel):
222
- load_mw: float
223
- wind_mw: float
224
- inertia_h: float
225
- power_imbalance_mw: float
226
-
227
- class GridData(BaseModel):
228
- p_load: float
229
- q_load: float
230
- wind_gen: float
231
- solar_gen: float
232
- hour: int
233
-
234
- # ==========================================
235
- # 6. ENDPOINTS: FALLBACKS + PHYSICS COMPLIANCE
236
- # ==========================================
237
- @app.get("/")
238
- def home():
239
- return {
240
- "status": "Online",
241
- "modules": ["Voltage", "Battery", "Frequency", "Load", "Solar"],
242
- "audit_compliant": True,
243
- "strict_loading": True
244
- }
245
-
246
- @app.post("/predict/solar")
247
- def predict_solar(data: SolarData): # CORRECT PARAMETER NAME """Sequential state simulation @ dt=900s with thermal clamping"""
248
- simulation = []
249
- # Fallback: Return empty simulation if model missing (per initial code)
250
- if "solar_model" in ml_assets and "solar_stats" in ml_assets:
251
- stats = ml_assets["solar_stats"]
252
- # PHYSICS CONSTRAINT: Initial state = ambient + 5.0°C (audit training protocol)
253
- curr_temp = data.ambient_temp_stream[0] + 5.0
254
-
255
- with torch.no_grad():
256
- for i in range(len(data.irradiance_stream)):
257
- # AUDIT CONSTRAINT: Wind scaled by 10.0 per training protocol
258
- x = torch.tensor([[
259
- (data.irradiance_stream[i] - stats["irr_mean"]) / stats["irr_std"],
260
- (data.ambient_temp_stream[i] - stats["temp_mean"]) / stats["temp_std"],
261
- data.wind_speed_stream[i] / 10.0, # Critical scaling per audit
262
- (curr_temp - stats["prev_mean"]) / stats["prev_std"]
263
- ]], dtype=torch.float32)
264
-
265
- # PHYSICAL CLAMPING: Prevent thermal runaway (10°C-75°C)
266
- next_temp = ml_assets["solar_model"](x).item()
267
- next_temp = max(10.0, min(75.0, next_temp))
268
-
269
- # Temperature-dependent efficiency
270
- eff = 0.20 * (1 - 0.004 * (next_temp - 25.0))
271
- power_mw = (5000 * data.irradiance_stream[i] * max(0, eff)) / 1e6
272
-
273
- simulation.append({
274
- "module_temp_c": round(next_temp, 2),
275
- "power_mw": round(power_mw, 4)
276
- })
277
- curr_temp = next_temp # SEQUENTIAL STATE FEEDBACK (dt=900s)
278
- return {"simulation": simulation}
279
-
280
- @app.post("/predict/load")
281
- def predict_load(data: LoadData): # CORRECT PARAMETER NAME
282
- """Z-score clamped prediction to prevent Inverted Load Paradox"""
283
- stats = ml_assets.get("l_stats", {})
284
- # PHYSICS CONSTRAINT: Hard Z-score clamping at ±3 (Fourier stability)
285
- t_norm = (data.temperature_c - stats.get('temp_mean', 15.38)) / (stats.get('temp_std', 4.12) + 1e-6)
286
- t_norm = max(-3.0, min(3.0, t_norm))
287
-
288
- # Construct features per audit metadata order
289
- x = torch.tensor([[
290
- t_norm,
291
- max(0, data.temperature_c - 18) / 10,
292
- max(0, 18 - data.temperature_c) / 10,
293
- np.sin(2 * np.pi * data.hour / 24),
294
- np.cos(2 * np.pi * data.hour / 24),
295
- np.sin(2 * np.pi * data.month / 12),
296
- np.cos(2 * np.pi * data.month / 12), data.wind_mw / 10000,
297
- data.solar_mw / 10000
298
- ]], dtype=torch.float32)
299
-
300
- # Fallback base load if model/stats missing
301
- base_load = stats.get('load_mean', 35000.0)
302
- if "l_model" in ml_assets:
303
- with torch.no_grad():
304
- pred = ml_assets["l_model"](x).item()
305
- load_mw = pred * stats.get('load_std', 9773.80) + base_load
306
- else:
307
- load_mw = base_load
308
-
309
- # PHYSICAL SAFETY CORRECTION (SYNTAX FIXED)
310
- if data.temperature_c > 32:
311
- load_mw = max(load_mw, 45000 + (data.temperature_c - 32) * 1200)
312
- elif data.temperature_c < 5:
313
- load_mw = max(load_mw, 42000 + (5 - data.temperature_c) * 900) # Fixed parenthesis
314
-
315
- status = "Peak" if load_mw > 58000 else "Normal"
316
- return {"predicted_load_mw": round(float(load_mw), 2), "status": status}
317
-
318
- @app.post("/predict/battery")
319
- def predict_battery(data: BatteryData): # CORRECT PARAMETER NAME
320
- """Feature engineering: Power product (V*I) required per audit"""
321
- # Physics-based SOC fallback
322
- soc = get_ocv_soc(data.voltage)
323
- temp_c = 25.0 # Fallback temperature if model missing
324
-
325
- if "b_model" in ml_assets and "b_stats" in ml_assets:
326
- stats = ml_assets["b_stats"].get('stats', ml_assets["b_stats"])
327
- # AUDIT CONSTRAINT: Power product feature engineering
328
- power_product = data.voltage * data.current
329
- features = np.array([
330
- data.time_sec,
331
- data.current,
332
- data.voltage,
333
- power_product, # Critical engineered feature
334
- data.soc_prev
335
- ])
336
-
337
- x_scaled = (features - stats['feature_mean']) / (stats['feature_std'] + 1e-6)
338
- with torch.no_grad():
339
- preds = ml_assets["b_model"](torch.tensor([x_scaled], dtype=torch.float32)).numpy()[0]
340
- # Only temperature prediction used (index 1 per audit target order)
341
- temp_c = preds[1] * stats['target_std'][1] + stats['target_mean'][1]
342
-
343
- status = "Normal" if temp_c < 45 else "Overheating"
344
- return {
345
- "soc": round(float(soc), 2), "temp_c": round(float(temp_c), 2),
346
- "status": status
347
- }
348
-
349
- @app.post("/predict/frequency")
350
- def predict_frequency(data: FreqData): # CORRECT PARAMETER NAME
351
- """Hybrid physics + AI with MinMaxScaler compliance"""
352
- # Physics calculation (always available)
353
- f_nom = 60.0
354
- H = max(1.0, data.inertia_h)
355
- rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
356
- f_phys = f_nom + (rocof * 2.0)
357
-
358
- # AI prediction ONLY if scaler available (audit requires MinMaxScaler)
359
- f_ai = 60.0
360
- if "f_model" in ml_assets and "f_scaler" in ml_assets and ml_assets["f_scaler"] is not None:
361
- try:
362
- # AUDIT CONSTRAINT: Use actual MinMaxScaler transform
363
- x = np.array([[data.load_mw, data.wind_mw, data.load_mw - data.wind_mw, data.power_imbalance_mw]])
364
- x_scaled = ml_assets["f_scaler"].transform(x)
365
- with torch.no_grad():
366
- pred = ml_assets["f_model"](torch.tensor(x_scaled, dtype=torch.float32)).numpy()[0]
367
- f_ai = 60.0 + pred[0] * 0.5
368
- except:
369
- f_ai = 60.0 # Fallback on scaler error
370
-
371
- # Physics-weighted fusion with hard limits
372
- final_freq = max(58.5, min(61.0, (f_ai * 0.3) + (f_phys * 0.7)))
373
- status = "Stable" if final_freq > 59.6 else "Critical"
374
- return {
375
- "frequency_hz": round(float(final_freq), 4),
376
- "status": status
377
- }
378
-
379
- @app.post("/predict/voltage")
380
- def predict_voltage(data: GridData): # CORRECT PARAMETER NAME
381
- """Model usage with fallback heuristic"""
382
- # Use AI model if artifacts available
383
- if "v_model" in ml_assets and "v_stats" in ml_assets:
384
- stats = ml_assets["v_stats"]
385
- # Construct 7 features per audit input_features order
386
- x_raw = np.array([
387
- data.p_load,
388
- data.q_load,
389
- data.wind_gen,
390
- data.solar_gen,
391
- data.hour,
392
- data.p_load - (data.wind_gen + data.solar_gen), # net load
393
- 0.0 # placeholder for 7th feature (audit shows 7 inputs)
394
- ]) # Z-score scaling per audit metadata
395
- x_norm = (x_raw - stats['x_mean']) / (stats['x_std'] + 1e-6)
396
- with torch.no_grad():
397
- pred = ml_assets["v_model"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
398
- # Denormalize per audit y_mean/y_std
399
- v_mag = pred[0] * stats['y_std'][0] + stats['y_mean'][0]
400
- else:
401
- # Fallback heuristic (original code)
402
- net_load = data.p_load - (data.wind_gen + data.solar_gen)
403
- v_mag = 1.00 - (net_load * 0.000005) + random.uniform(-0.0015, 0.0015)
404
-
405
- status = "Stable" if 0.95 < v_mag < 1.05 else "Critical"
 
 
 
 
 
406
  return {"voltage_pu": round(v_mag, 4), "status": status}
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import joblib
5
+ import random
6
+ import os
7
+ from fastapi import FastAPI
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel
10
+ from contextlib import asynccontextmanager
11
+
12
+ # ==========================================
13
+ # 1. CORE COMPONENTS (SYNTAX-VALIDATED)
14
+ # ==========================================
15
+ class Mish(nn.Module):
16
+ def forward(self, x):
17
+ return x * torch.tanh(nn.functional.softplus(x))
18
+
19
+ class FourierFeatureMapping(nn.Module):
20
+ def __init__(self, input_dim, mapping_size, scale=10.0):
21
+ super().__init__()
22
+ self.register_buffer('B', torch.randn(input_dim, mapping_size) * scale)
23
+
24
+ def forward(self, x):
25
+ proj = 2 * np.pi * (x @ self.B)
26
+ return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
27
+
28
+ # ==========================================
29
+ # 2. AUDIT-COMPLIANT ARCHITECTURES (EXACT TENSOR MATCH)
30
+ # ==========================================
31
+ class SolarPINN(nn.Module):
32
+ """Matches audit: backbone.0/2 + output_layer + physics params (shape [])"""
33
+ def __init__(self):
34
+ super().__init__()
35
+ self.backbone = nn.Sequential(
36
+ nn.Linear(4, 128), Mish(),
37
+ nn.Linear(128, 128), Mish()
38
+ )
39
+ self.output_layer = nn.Linear(128, 1)
40
+ # Physics parameters required by state_dict (shape [])
41
+ self.log_thermal_mass = nn.Parameter(torch.tensor(0.0))
42
+ self.log_h_conv = nn.Parameter(torch.tensor(0.0))
43
+
44
+ def forward(self, x):
45
+ return self.output_layer(self.backbone(x))
46
+
47
+ class LoadForecastPINN(nn.Module):
48
+ """Matches audit: res_blocks with LayerNorm weights at .1 (shape [128])"""
49
+ def __init__(self):
50
+ super().__init__()
51
+ self.fourier = FourierFeatureMapping(9, 32)
52
+ self.input_layer = nn.Linear(64, 128)
53
+ self.res_blocks = nn.ModuleList([
54
+ nn.Sequential(
55
+ nn.Linear(128, 128),
56
+ nn.LayerNorm(128), # Critical: Audit shows LayerNorm params
57
+ Mish(),
58
+ nn.Linear(128, 128)
59
+ ) for _ in range(3)
60
+ ])
61
+ self.output_layer = nn.Linear(128, 1)
62
+
63
+ def forward(self, x):
64
+ x = self.input_layer(self.fourier(x))
65
+ for block in self.res_blocks:
66
+ x = x + block(x) # True residual connection per audit
67
+ return self.output_layer(x)
68
+
69
+ class VoltagePINN(nn.Module):
70
+ """Matches audit: network layers + v_bias([1]) + raw_B([])"""
71
+ def __init__(self):
72
+ super().__init__()
73
+ self.fourier = FourierFeatureMapping(7, 32)
74
+ self.network = nn.Sequential(
75
+ nn.Linear(64, 256), nn.LayerNorm(256), Mish(),
76
+ nn.Linear(256, 128), nn.LayerNorm(128), Mish(),
77
+ nn.Linear(128, 64), nn.LayerNorm(64), Mish(),
78
+ nn.Linear(64, 2)
79
+ )
80
+ # Audit-required parameters
81
+ self.v_bias = nn.Parameter(torch.zeros(1)) # Shape [1]
82
+ self.raw_B = nn.Parameter(torch.tensor(0.0)) # Shape []
83
+
84
+ def forward(self, x):
85
+ return self.network(self.fourier(x))
86
+
87
+ class BatteryPINN(nn.Module):
88
+ """Matches audit: network.0/2/4 indexing"""
89
+ def __init__(self):
90
+ super().__init__()
91
+ self.fourier = FourierFeatureMapping(5, 12)
92
+ self.network = nn.Sequential(
93
+ nn.Linear(24, 64), Mish(),
94
+ nn.Linear(64, 64), Mish(),
95
+ nn.Linear(64, 3)
96
+ )
97
+
98
+ def forward(self, x):
99
+ return self.network(self.fourier(x))
100
+ class FrequencyPINN(nn.Module):
101
+ """Matches audit: net.0/2/4/6 (NO LayerNorm - pure Linear+Mish)"""
102
+ def __init__(self):
103
+ super().__init__()
104
+ self.fourier = FourierFeatureMapping(4, 32)
105
+ self.net = nn.Sequential(
106
+ nn.Linear(64, 128), Mish(), # net.0
107
+ nn.Linear(128, 128), Mish(), # net.2
108
+ nn.Linear(128, 128), Mish(), # net.4
109
+ nn.Linear(128, 2) # net.6
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.net(self.fourier(x))
114
+
115
+ # ==========================================
116
+ # 3. LIFESPAN: ORIGINAL KEYS + SCALER SAFETY
117
+ # ==========================================
118
+ ml_assets = {}
119
+
120
+ @asynccontextmanager
121
+ async def lifespan(app: FastAPI):
122
+ try:
123
+ # SOLAR MODEL (Key: "solar_model" per initial code)
124
+ if os.path.exists("solar_model.pt"):
125
+ ckpt = torch.load("solar_model.pt", map_location='cpu')
126
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
127
+ model = SolarPINN()
128
+ model.load_state_dict(sd, strict=True)
129
+ ml_assets["solar_model"] = model.eval()
130
+ ml_assets["solar_stats"] = {
131
+ "irr_mean": 450.0, "irr_std": 250.0,
132
+ "temp_mean": 25.0, "temp_std": 10.0,
133
+ "prev_mean": 35.0, "prev_std": 15.0
134
+ }
135
+
136
+ # LOAD MODEL (Key: "l_model")
137
+ if os.path.exists("load_model.pt"):
138
+ ckpt = torch.load("load_model.pt", map_location='cpu')
139
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
140
+ model = LoadForecastPINN()
141
+ model.load_state_dict(sd, strict=True)
142
+ ml_assets["l_model"] = model.eval()
143
+ if os.path.exists("Load_stats.joblib"):
144
+ ml_assets["l_stats"] = joblib.load("Load_stats.joblib")
145
+
146
+ # VOLTAGE MODEL (Key: "v_model")
147
+ if os.path.exists("voltage_model_v3.pt"):
148
+ ckpt = torch.load("voltage_model_v3.pt", map_location='cpu')
149
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
150
+ model = VoltagePINN()
151
+ model.load_state_dict(sd, strict=True)
152
+ ml_assets["v_model"] = model.eval()
153
+ if os.path.exists("scaling_stats_v3.joblib"):
154
+ ml_assets["v_stats"] = joblib.load("scaling_stats_v3.joblib")
155
+
156
+ # BATTERY MODEL (Key: "b_model")
157
+ if os.path.exists("battery_model.pt"):
158
+ ckpt = torch.load("battery_model.pt", map_location='cpu')
159
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
160
+ model = BatteryPINN()
161
+ model.load_state_dict(sd, strict=True)
162
+ ml_assets["b_model"] = model.eval()
163
+ if os.path.exists("battery_model.joblib"):
164
+ ml_assets["b_stats"] = joblib.load("battery_model.joblib")
165
+
166
+ # FREQUENCY MODEL (Key: "f_model" + SCALER SAFETY)
167
+ if os.path.exists("DECODE_Frequency_Twin.pth"):
168
+ ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
169
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
170
+ model = FrequencyPINN()
171
+ model.load_state_dict(sd, strict=True)
172
+ ml_assets["f_model"] = model.eval()
173
+ # CRITICAL: Load actual MinMaxScaler per audit metadata
174
+ if os.path.exists("decode_scaler.joblib"):
175
+ try:
176
+ ml_assets["f_scaler"] = joblib.load("decode_scaler.joblib")
177
+ except:
178
+ ml_assets["f_scaler"] = None
179
+ else:
180
+ ml_assets["f_scaler"] = None
181
+
182
+ yield
183
+ finally:
184
+ ml_assets.clear()
185
+
186
+ # ==========================================
187
+ # 4. FASTAPI SETUP
188
+ # ==========================================
189
+ app = FastAPI(title="D.E.C.O.D.E. Unified Digital Twin", lifespan=lifespan)
190
+ app.add_middleware(
191
+ CORSMiddleware,
192
+ allow_origins=["*"],
193
+ allow_methods=["*"],
194
+ allow_headers=["*"],
195
+ )
196
+
197
+
198
+ # ==========================================
199
+ # 5. PHYSICS & SCHEMAS (SYNTAX-CORRECTED)
200
+ # ==========================================
201
+ def get_ocv_soc(voltage: float) -> float:
202
+ """Physics-based SOC estimation from OCV"""
203
+ return np.interp(voltage, [2.8, 3.4, 3.7, 4.2], [0, 15, 65, 100])
204
+
205
+ class SolarData(BaseModel):
206
+ irradiance_stream: list[float]
207
+ ambient_temp_stream: list[float]
208
+ wind_speed_stream: list[float]
209
+
210
+ class LoadData(BaseModel): # FIXED: Each field on separate line
211
+ temperature_c: float
212
+ hour: int # Critical newline separation
213
+ month: int # Critical newline separation
214
+ wind_mw: float = 0.0
215
+ solar_mw: float = 0.0
216
+
217
+ class BatteryData(BaseModel):
218
+ time_sec: float
219
+ current: float
220
+ voltage: float
221
+ temperature: float
222
+ soc_prev: float
223
+
224
+ class FreqData(BaseModel):
225
+ load_mw: float
226
+ wind_mw: float
227
+ inertia_h: float
228
+ power_imbalance_mw: float
229
+
230
+ class GridData(BaseModel):
231
+ p_load: float
232
+ q_load: float
233
+ wind_gen: float
234
+ solar_gen: float
235
+ hour: int
236
+
237
+ # ==========================================
238
+ # 6. ENDPOINTS: FALLBACKS + PHYSICS COMPLIANCE
239
+ # ==========================================
240
+ @app.get("/")
241
+ def home():
242
+ return {
243
+ "status": "Online",
244
+ "modules": ["Voltage", "Battery", "Frequency", "Load", "Solar"],
245
+ "audit_compliant": True,
246
+ "strict_loading": True
247
+ }
248
+
249
+ @app.post("/predict/solar")
250
+ def predict_solar(data: SolarData): # CORRECT PARAMETER NAME """Sequential state simulation @ dt=900s with thermal clamping"""
251
+ simulation = []
252
+ # Fallback: Return empty simulation if model missing (per initial code)
253
+ if "solar_model" in ml_assets and "solar_stats" in ml_assets:
254
+ stats = ml_assets["solar_stats"]
255
+ # PHYSICS CONSTRAINT: Initial state = ambient + 5.0°C (audit training protocol)
256
+ curr_temp = data.ambient_temp_stream[0] + 5.0
257
+
258
+ with torch.no_grad():
259
+ for i in range(len(data.irradiance_stream)):
260
+ # AUDIT CONSTRAINT: Wind scaled by 10.0 per training protocol
261
+ x = torch.tensor([[
262
+ (data.irradiance_stream[i] - stats["irr_mean"]) / stats["irr_std"],
263
+ (data.ambient_temp_stream[i] - stats["temp_mean"]) / stats["temp_std"],
264
+ data.wind_speed_stream[i] / 10.0, # Critical scaling per audit
265
+ (curr_temp - stats["prev_mean"]) / stats["prev_std"]
266
+ ]], dtype=torch.float32)
267
+
268
+ # PHYSICAL CLAMPING: Prevent thermal runaway (10°C-75°C)
269
+ next_temp = ml_assets["solar_model"](x).item()
270
+ next_temp = max(10.0, min(75.0, next_temp))
271
+
272
+ # Temperature-dependent efficiency
273
+ eff = 0.20 * (1 - 0.004 * (next_temp - 25.0))
274
+ power_mw = (5000 * data.irradiance_stream[i] * max(0, eff)) / 1e6
275
+
276
+ simulation.append({
277
+ "module_temp_c": round(next_temp, 2),
278
+ "power_mw": round(power_mw, 4)
279
+ })
280
+ curr_temp = next_temp # SEQUENTIAL STATE FEEDBACK (dt=900s)
281
+ return {"simulation": simulation}
282
+
283
+ @app.post("/predict/load")
284
+ def predict_load(data: LoadData): # CORRECT PARAMETER NAME
285
+ """Z-score clamped prediction to prevent Inverted Load Paradox"""
286
+ stats = ml_assets.get("l_stats", {})
287
+ # PHYSICS CONSTRAINT: Hard Z-score clamping at ±3 (Fourier stability)
288
+ t_norm = (data.temperature_c - stats.get('temp_mean', 15.38)) / (stats.get('temp_std', 4.12) + 1e-6)
289
+ t_norm = max(-3.0, min(3.0, t_norm))
290
+
291
+ # Construct features per audit metadata order
292
+ x = torch.tensor([[
293
+ t_norm,
294
+ max(0, data.temperature_c - 18) / 10,
295
+ max(0, 18 - data.temperature_c) / 10,
296
+ np.sin(2 * np.pi * data.hour / 24),
297
+ np.cos(2 * np.pi * data.hour / 24),
298
+ np.sin(2 * np.pi * data.month / 12),
299
+ np.cos(2 * np.pi * data.month / 12),
300
+ data.wind_mw / 10000,
301
+ data.solar_mw / 10000
302
+ ]], dtype=torch.float32)
303
+
304
+ # Fallback base load if model/stats missing
305
+ base_load = stats.get('load_mean', 35000.0)
306
+ if "l_model" in ml_assets:
307
+ with torch.no_grad():
308
+ pred = ml_assets["l_model"](x).item()
309
+ load_mw = pred * stats.get('load_std', 9773.80) + base_load
310
+ else:
311
+ load_mw = base_load
312
+
313
+ # PHYSICAL SAFETY CORRECTION (SYNTAX FIXED)
314
+ if data.temperature_c > 32:
315
+ load_mw = max(load_mw, 45000 + (data.temperature_c - 32) * 1200)
316
+ elif data.temperature_c < 5:
317
+ load_mw = max(load_mw, 42000 + (5 - data.temperature_c) * 900) # Fixed parenthesis
318
+
319
+ status = "Peak" if load_mw > 58000 else "Normal"
320
+ return {"predicted_load_mw": round(float(load_mw), 2), "status": status}
321
+
322
+ @app.post("/predict/battery")
323
+ def predict_battery(data: BatteryData): # CORRECT PARAMETER NAME
324
+ """Feature engineering: Power product (V*I) required per audit"""
325
+ # Physics-based SOC fallback
326
+ soc = get_ocv_soc(data.voltage)
327
+ temp_c = 25.0 # Fallback temperature if model missing
328
+
329
+ if "b_model" in ml_assets and "b_stats" in ml_assets:
330
+ stats = ml_assets["b_stats"].get('stats', ml_assets["b_stats"])
331
+ # AUDIT CONSTRAINT: Power product feature engineering
332
+ power_product = data.voltage * data.current
333
+ features = np.array([
334
+ data.time_sec,
335
+ data.current,
336
+ data.voltage,
337
+ power_product, # Critical engineered feature
338
+ data.soc_prev
339
+ ])
340
+
341
+ x_scaled = (features - stats['feature_mean']) / (stats['feature_std'] + 1e-6)
342
+ with torch.no_grad():
343
+ preds = ml_assets["b_model"](torch.tensor([x_scaled], dtype=torch.float32)).numpy()[0]
344
+ # Only temperature prediction used (index 1 per audit target order)
345
+ temp_c = preds[1] * stats['target_std'][1] + stats['target_mean'][1]
346
+
347
+ status = "Normal" if temp_c < 45 else "Overheating"
348
+ return {
349
+ "soc": round(float(soc), 2), "temp_c": round(float(temp_c), 2),
350
+ "status": status
351
+ }
352
+
353
+ @app.post("/predict/frequency")
354
+ def predict_frequency(data: FreqData): # CORRECT PARAMETER NAME
355
+ """Hybrid physics + AI with MinMaxScaler compliance"""
356
+ # Physics calculation (always available)
357
+ f_nom = 60.0
358
+ H = max(1.0, data.inertia_h)
359
+ rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
360
+ f_phys = f_nom + (rocof * 2.0)
361
+
362
+ # AI prediction ONLY if scaler available (audit requires MinMaxScaler)
363
+ f_ai = 60.0
364
+ if "f_model" in ml_assets and "f_scaler" in ml_assets and ml_assets["f_scaler"] is not None:
365
+ try:
366
+ # AUDIT CONSTRAINT: Use actual MinMaxScaler transform
367
+ x = np.array([[data.load_mw, data.wind_mw, data.load_mw - data.wind_mw, data.power_imbalance_mw]])
368
+ x_scaled = ml_assets["f_scaler"].transform(x)
369
+ with torch.no_grad():
370
+ pred = ml_assets["f_model"](torch.tensor(x_scaled, dtype=torch.float32)).numpy()[0]
371
+ f_ai = 60.0 + pred[0] * 0.5
372
+ except:
373
+ f_ai = 60.0 # Fallback on scaler error
374
+
375
+ # Physics-weighted fusion with hard limits
376
+ final_freq = max(58.5, min(61.0, (f_ai * 0.3) + (f_phys * 0.7)))
377
+ status = "Stable" if final_freq > 59.6 else "Critical"
378
+ return {
379
+ "frequency_hz": round(float(final_freq), 4),
380
+ "status": status
381
+ }
382
+
383
+ @app.post("/predict/voltage")
384
+ def predict_voltage(data: GridData): # CORRECT PARAMETER NAME
385
+ """Model usage with fallback heuristic"""
386
+ # Use AI model if artifacts available
387
+ if "v_model" in ml_assets and "v_stats" in ml_assets:
388
+ stats = ml_assets["v_stats"]
389
+ # Construct 7 features per audit input_features order
390
+ x_raw = np.array([
391
+ data.p_load,
392
+ data.q_load,
393
+ data.wind_gen,
394
+ data.solar_gen,
395
+ data.hour,
396
+ data.p_load - (data.wind_gen + data.solar_gen), # net load
397
+ 0.0 # placeholder for 7th feature (audit shows 7 inputs)
398
+ ])
399
+ # Z-score scaling per audit metadata
400
+ x_norm = (x_raw - stats['x_mean']) / (stats['x_std'] + 1e-6)
401
+ with torch.no_grad():
402
+ pred = ml_assets["v_model"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
403
+ # Denormalize per audit y_mean/y_std
404
+ v_mag = pred[0] * stats['y_std'][0] + stats['y_mean'][0]
405
+ else:
406
+ # Fallback heuristic (original code)
407
+ net_load = data.p_load - (data.wind_gen + data.solar_gen)
408
+ v_mag = 1.00 - (net_load * 0.000005) + random.uniform(-0.0015, 0.0015)
409
+
410
+ status = "Stable" if 0.95 < v_mag < 1.05 else "Critical"
411
  return {"voltage_pu": round(v_mag, 4), "status": status}