""" Custom Loss Functions for TFT-ASRO. Implements: - AdaptiveSharpeRatioLoss (ASRO): jointly optimises risk-adjusted return, volatility calibration, and quantile coverage. - MeanAbsoluteDirectionalLoss (MADL): directly optimises directional accuracy weighted by move magnitude. - CombinedQuantileLoss: standard multi-quantile pinball loss used as a component of ASRO and as a standalone baseline. References: - Kisiel & Gorse (2023) "Mean Absolute Directional Loss" (ScienceDirect) - Kisiel & Gorse (2024) "Generalized MADL" (arXiv:2412.18405) """ from __future__ import annotations from typing import Optional, Sequence, Union import torch import torch.nn as nn import numpy as np from deep_learning.config import ASROConfig def debug_asro_loss_direction() -> dict: """ ASRO kayıp fonksiyonunun matematiksel doğrulaması. Üç test senaryosu: 1. correct_direction : tanh(pred) ile actual aynı işaret → loss minimum, Sharpe pozitif 2. anti_direction : tanh(pred) ile actual ters işaret → loss maksimum, Sharpe negatif 3. zero_predictions : model sıfır tahmin üretiyor → Sharpe sıfır (dar varyans tuzağı) Gradyan kontrolleri: - Her senaryoda grad_norm > 0 olmalı (tanh türevi var, sign() yok) - Doğru yönde kayıp < sıfır tahmin < ters yön kaybı sırası bozulmamalı Returns: { "passed": bool, "results": {scenario: {"loss", "grad_norm", "strategy_sharpe"}}, "diagnostics": str # geçti/kaldı açıklaması } """ import torch torch.manual_seed(42) B, T, Q = 64, 5, 7 quantiles = [0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98] actual_std = 0.024 actual = torch.randn(B, T) * actual_std def _make_preds(median: torch.Tensor) -> torch.Tensor: """Build a quantile tensor from a given median, spread ≈ 2*actual_std.""" out = torch.zeros(B, T, Q) for i, q in enumerate(quantiles): out[..., i] = median + (q - 0.5) * actual_std * 2 return out scenarios = { "correct_direction": _make_preds(actual * 0.5), "anti_direction": _make_preds(-actual * 0.5), "zero_predictions": _make_preds(torch.zeros(B, T)), } fn = AdaptiveSharpeRatioLoss(quantiles=quantiles) results: dict = {} for name, preds in scenarios.items(): p = preds.detach().requires_grad_(True) loss_val = fn(p, actual.detach()) loss_val.backward() grad_norm = float(p.grad.norm().item()) if p.grad is not None else 0.0 with torch.no_grad(): med = p.detach()[..., len(quantiles) // 2] signal = torch.tanh(med * 20.0) # same scale as training loss sr = float( (signal * actual).mean() / ((signal * actual).std() + 1e-6) ) results[name] = { "loss": round(float(loss_val.item()), 6), "grad_norm": round(grad_norm, 6), "strategy_sharpe": round(sr, 4), } checks = { "correct < anti loss": results["correct_direction"]["loss"] < results["anti_direction"]["loss"], "correct Sharpe > 0": results["correct_direction"]["strategy_sharpe"] > 0, "anti Sharpe < 0": results["anti_direction"]["strategy_sharpe"] < 0, "gradients non-zero (correct)": results["correct_direction"]["grad_norm"] > 1e-6, "gradients non-zero (anti)": results["anti_direction"]["grad_norm"] > 1e-6, } passed = all(checks.values()) failed = [k for k, v in checks.items() if not v] diagnostics = "ALL CHECKS PASSED" if passed else f"FAILED: {failed}" return {"passed": passed, "results": results, "diagnostics": diagnostics} class MeanAbsoluteDirectionalLoss(nn.Module): """ MADL: Mean Absolute Directional Loss. Directly optimises for directional accuracy weighted by |actual_return|. Correct-direction predictions receive negative loss (reward); incorrect predictions receive positive loss (penalty), scaled by move magnitude. MADL = mean(-sign(pred × actual) × |actual|) Properties: - Large moves dominate the gradient (they're predictable and impactful) - Small/noisy moves contribute almost nothing (correctly ignored) - No binary labels needed (unlike BCE which failed due to ambiguous sigmoid targets for copper's ±0.5% daily returns) - Differentiable via tanh soft-sign approximation """ def __init__(self, tanh_scale: float = 20.0, eps: float = 1e-7): super().__init__() self.tanh_scale = tanh_scale self.eps = eps def forward( self, y_pred_median: torch.Tensor, y_actual: torch.Tensor, ) -> torch.Tensor: """ Args: y_pred_median: (batch, prediction_length) — median predictions y_actual: (batch, prediction_length) — actual returns """ soft_sign = torch.tanh(y_pred_median * self.tanh_scale) direction_match = soft_sign * y_actual madl = (-direction_match * y_actual.abs()).mean() return madl class CombinedQuantileLoss(nn.Module): """ Multi-quantile pinball loss. Given K quantile predictions and actual values, the loss is the average pinball loss across all quantiles and samples. """ def __init__(self, quantiles: Sequence[float] = (0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98)): super().__init__() self.register_buffer( "quantiles", torch.tensor(quantiles, dtype=torch.float32), ) def forward( self, y_pred: torch.Tensor, y_actual: torch.Tensor, ) -> torch.Tensor: """ Args: y_pred: (batch, prediction_length, n_quantiles) y_actual: (batch, prediction_length) """ if y_actual.dim() == 2: y_actual = y_actual.unsqueeze(-1) errors = y_actual - y_pred quantiles = self.quantiles.view(1, 1, -1) loss = torch.max(quantiles * errors, (quantiles - 1) * errors) return loss.mean() def quantile_crossing_penalty(y_pred: torch.Tensor) -> torch.Tensor: """ Penalise adjacent quantile inversions. Quantile heads must be monotonic in the quantile dimension. Pinball loss alone does not guarantee that ordering, especially when q50 also receives directional Sharpe/MADL gradients. """ if y_pred.shape[-1] < 2: return y_pred.new_tensor(0.0) diffs = y_pred[..., 1:] - y_pred[..., :-1] return torch.relu(-diffs).mean() class AdaptiveSharpeRatioLoss(nn.Module): """ TFT-ASRO loss: combines three objectives to break the low-variance trap. L = -Sharpe_component + lambda_vol * volatility_calibration_loss + lambda_quantile * quantile_coverage_loss The Sharpe component incentivises the model to produce directionally correct predictions (not just low MSE), while the volatility term penalises under-estimation of realised variance, and the quantile term ensures proper tail coverage. """ def __init__( self, quantiles: Sequence[float] = (0.02, 0.10, 0.25, 0.50, 0.75, 0.90, 0.98), lambda_vol: float = 0.3, lambda_quantile: float = 0.2, lambda_madl: float = 0.25, lambda_crossing: float = 1.0, risk_free_rate: float = 0.0, sharpe_eps: float = 1e-6, median_idx: Optional[int] = None, ): super().__init__() self.lambda_vol = lambda_vol self.lambda_quantile = lambda_quantile self.lambda_madl = lambda_madl self.lambda_crossing = lambda_crossing self.rf = risk_free_rate self.sharpe_eps = sharpe_eps self.median_idx = median_idx if median_idx is not None else len(quantiles) // 2 self.quantile_loss = CombinedQuantileLoss(quantiles) self.madl = MeanAbsoluteDirectionalLoss() q = list(quantiles) self._q10_idx = q.index(0.10) if 0.10 in q else 1 self._q90_idx = q.index(0.90) if 0.90 in q else len(q) - 2 def forward( self, y_pred: torch.Tensor, y_actual: torch.Tensor, ) -> torch.Tensor: """ Args: y_pred: (batch, prediction_length, n_quantiles) y_actual: (batch, prediction_length) """ median_pred = y_pred[:, :, self.median_idx] y_actual_f = y_actual.float() # --- Directional Sharpe component: sample-level reward --- # Previous batch-level Sharpe (mean/std over 160 shuffled values) produced # noisy gradients that converged to zero-signal equilibrium. The new # formulation has two parts: # # 1. directional_reward = mean(signal_i × actual_i) [sample-level] # Each sample gets a clear gradient: "was your direction correct?" # # 2. risk_normalisation = std(strategy_returns) + eps [batch-level] # Prevents the model from inflating predictions to game the reward. # # Combined: sharpe_loss = -directional_reward / risk_normalisation # This is mathematically equivalent to the Sharpe ratio but the gradient # of mean(signal_i × actual_i) w.r.t. pred_j only depends on sample j, # breaking the "batch-average safe mode" trap. _TANH_SCALE = 20.0 signal = torch.tanh(median_pred * _TANH_SCALE) strategy_returns = signal * y_actual_f - self.rf directional_reward = (signal * y_actual_f).mean() risk_norm = strategy_returns.std() + self.sharpe_eps sharpe_loss = -directional_reward / risk_norm # --- Magnitude-weighted directional bonus --- # BCE on sign labels was counterproductive: sigmoid(±0.005 * 100) ≈ 0.62/0.38 # creates ambiguous targets for the majority of copper daily returns, # causing the model to memorise noise and develop anti-correlation (DA=43.9%). # # Instead, weight each sample's directional contribution by |actual_return|. # Large moves get strong "get the direction right" gradient; # near-zero moves get almost no directional signal (they're unpredictable). abs_actual = y_actual_f.abs() magnitude_weight = abs_actual / (abs_actual.mean() + self.sharpe_eps) weighted_directional = (signal * y_actual_f * magnitude_weight).mean() sharpe_loss = sharpe_loss - 0.3 * weighted_directional # --- Volatility calibration --- # Match Q90-Q10 spread to 2× actual σ so the prediction interval tracks # realised volatility rather than collapsing to a constant. pred_spread = (y_pred[:, :, self._q90_idx] - y_pred[:, :, self._q10_idx]).mean() actual_std = y_actual_f.std() + self.sharpe_eps vol_loss = torch.abs(pred_spread - 2.0 * actual_std) # --- Median amplitude penalty --- # vol_loss only targets the Q10-Q90 band width; the model can widen bands # while keeping median predictions flat. This term directly penalises the # median for having lower variance than actual returns. # relu(1 - VR) fires when pred_std < actual_std; zero otherwise. median_std = median_pred.std() + self.sharpe_eps vr = median_std / actual_std # Two-tier under-variance penalty: # VR < 0.5 → severe (×2.0): model is in flat-prediction mode (VR=0.39 observed) # VR < 1.0 → moderate (×1.0): pred_std is below actual_std but not collapsed # Over-variance penalty unchanged at 1.0× for VR > 1.5. under_severe = 2.0 * torch.relu(0.5 - vr) # fires hard when VR < 0.5 under_moderate = torch.relu(1.0 - vr) # fires when VR < 1.0 over_variance = 1.0 * torch.relu(vr - 1.5) amplitude_loss = under_severe + under_moderate + over_variance # --- Quantile (pinball) loss --- q_loss = self.quantile_loss(y_pred, y_actual) crossing_loss = quantile_crossing_penalty(y_pred) # --- MADL: direct directional accuracy optimisation --- madl_loss = self.madl(median_pred, y_actual_f) # --- Weighted combination --- # calibration = quantile bands + band width + median amplitude # directional = Sharpe (risk-normalised) + MADL (magnitude-weighted) # w_quantile + w_directional = 1.0 w_directional = 1.0 - self.lambda_quantile calibration = ( q_loss + self.lambda_vol * (vol_loss + amplitude_loss) + self.lambda_crossing * crossing_loss ) directional = sharpe_loss + self.lambda_madl * madl_loss total = self.lambda_quantile * calibration + w_directional * directional return total @classmethod def from_config(cls, cfg: ASROConfig, quantiles: Sequence[float]) -> "AdaptiveSharpeRatioLoss": return cls( quantiles=quantiles, lambda_vol=cfg.lambda_vol, lambda_quantile=cfg.lambda_quantile, lambda_madl=cfg.lambda_madl, lambda_crossing=cfg.lambda_crossing, risk_free_rate=cfg.risk_free_rate, )