Spaces:
Running
Running
Upload Quasar_axrvi_ranker.py
Browse files- Quasar_axrvi_ranker.py +159 -126
Quasar_axrvi_ranker.py
CHANGED
|
@@ -3939,10 +3939,6 @@ class QCSAMCrossAssetLayer(nn.Module):
|
|
| 3939 |
"""
|
| 3940 |
Live QCSAM/FABLE cross-asset core — integrates the full quantum pipeline
|
| 3941 |
into the AXRVINet forward pass.
|
| 3942 |
-
"""
|
| 3943 |
-
"""
|
| 3944 |
-
Live QCSAM/FABLE cross-asset core — integrates the full quantum pipeline
|
| 3945 |
-
into the AXRVINet forward pass.
|
| 3946 |
|
| 3947 |
This is the component that turns dead QCSAM/FABLE code into the active
|
| 3948 |
cross-asset interaction engine of the ranker.
|
|
@@ -3973,6 +3969,14 @@ class QCSAMCrossAssetLayer(nn.Module):
|
|
| 3973 |
qffn_output : (B, hilbert_dim) cdouble
|
| 3974 |
readout_input : (B, N, hilbert_dim) float32 (real part broadcast)
|
| 3975 |
layer_output : (B, N, d_model) float32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3976 |
"""
|
| 3977 |
|
| 3978 |
# Registry flag — set True after first successful forward.
|
|
@@ -4170,7 +4174,7 @@ class AXRVINet(nn.Module):
|
|
| 4170 |
# directly (e.g. test_components) still get the right architecture.
|
| 4171 |
if feature_dim is not None or seq_len is not None:
|
| 4172 |
import copy
|
| 4173 |
-
config =
|
| 4174 |
if feature_dim is not None:
|
| 4175 |
config.feature_dim = feature_dim
|
| 4176 |
if seq_len is not None:
|
|
@@ -5637,7 +5641,9 @@ class HybridTrainer:
|
|
| 5637 |
self.lambda_align= ckpt.get("lambda_align", self.lambda_align) # QCSAM
|
| 5638 |
self.rank_margin = ckpt.get("rank_margin", self.rank_margin)
|
| 5639 |
if "loss_history" in ckpt:
|
| 5640 |
-
|
|
|
|
|
|
|
| 5641 |
logger.info(f"✅ Model loaded ← {path} | resumed from train_step={self.train_step}")
|
| 5642 |
except FileNotFoundError:
|
| 5643 |
logger.info(f"ℹ️ No checkpoint at {path} — starting fresh")
|
|
@@ -6677,7 +6683,7 @@ class QuasarAXRVIBridge:
|
|
| 6677 |
hub_ws_url: str = os.environ.get("QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe"),
|
| 6678 |
enable_logging: bool = True,
|
| 6679 |
checkpoint_dir: str = "./Ranker6", # folder for full-state checkpoints
|
| 6680 |
-
resume: bool =
|
| 6681 |
hf_repo_id: Optional[str] = "KarlQuant/k1rl-checkpoints", # HF Dataset repo
|
| 6682 |
):
|
| 6683 |
self.config = config or AssetRankerConfig()
|
|
@@ -8242,22 +8248,22 @@ class QuasarAXRVIBridge:
|
|
| 8242 |
f"ce={train_result.get('ce', 0):.4f} | "
|
| 8243 |
f"batch_size={len(batch)}"
|
| 8244 |
)
|
| 8245 |
-
# ── HF SYNC: mirror trainer .pt to HF
|
| 8246 |
-
#
|
| 8247 |
-
#
|
| 8248 |
-
#
|
| 8249 |
-
|
| 8250 |
-
|
| 8251 |
-
|
| 8252 |
-
|
| 8253 |
-
|
| 8254 |
-
|
| 8255 |
-
|
| 8256 |
-
|
| 8257 |
-
|
| 8258 |
-
|
| 8259 |
-
|
| 8260 |
-
|
| 8261 |
else:
|
| 8262 |
logger.warning(
|
| 8263 |
"⚠️ [Training] train_on_batch returned empty result — "
|
|
@@ -8576,6 +8582,11 @@ class HFDatasetCheckpointManager:
|
|
| 8576 |
"""
|
| 8577 |
Upload *local_path* to HF Dataset as step_XXXXXXX.pt.
|
| 8578 |
Updates ranker_index.json and latest sentinel on HF.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8579 |
Returns True on success, False on any error (non-raising).
|
| 8580 |
"""
|
| 8581 |
if not self._ensure_hf():
|
|
@@ -8586,95 +8597,95 @@ class HFDatasetCheckpointManager:
|
|
| 8586 |
logger.warning(f"☁️ HF upload skipped — file not found: {local_path}")
|
| 8587 |
return False
|
| 8588 |
|
|
|
|
|
|
|
|
|
|
| 8589 |
hf_path = self._hf_ckpt_path(step)
|
| 8590 |
try:
|
| 8591 |
if self.verbose:
|
| 8592 |
-
logger.info(f"☁️ Uploading {hf_path} → {self.repo_id} …")
|
| 8593 |
|
| 8594 |
-
|
| 8595 |
-
|
| 8596 |
-
|
| 8597 |
-
|
| 8598 |
-
|
| 8599 |
-
|
| 8600 |
-
|
| 8601 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8602 |
|
| 8603 |
-
|
| 8604 |
-
|
| 8605 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8606 |
|
| 8607 |
self._upload_count += 1
|
| 8608 |
self._last_upload_step = step
|
| 8609 |
-
logger.info(
|
|
|
|
|
|
|
| 8610 |
return True
|
| 8611 |
|
| 8612 |
except Exception as exc:
|
| 8613 |
logger.warning(f"☁️ ⚠️ HF upload failed (non-fatal): {exc}")
|
| 8614 |
return False
|
| 8615 |
|
| 8616 |
-
|
| 8617 |
-
|
| 8618 |
-
|
| 8619 |
-
|
| 8620 |
-
metadata: Optional[Dict],
|
| 8621 |
-
) -> None:
|
| 8622 |
-
"""Download existing HF index, append entry, re-upload."""
|
| 8623 |
-
import tempfile, json as _json
|
| 8624 |
-
|
| 8625 |
-
existing: Dict[str, Any] = {}
|
| 8626 |
-
try:
|
| 8627 |
-
with tempfile.TemporaryDirectory() as tmpdir:
|
| 8628 |
-
idx_file = self._hf_hub_dl(
|
| 8629 |
-
repo_id=self.repo_id,
|
| 8630 |
-
filename=self.INDEX_FILENAME,
|
| 8631 |
-
repo_type="dataset",
|
| 8632 |
-
token=self.token,
|
| 8633 |
-
local_dir=tmpdir,
|
| 8634 |
-
local_dir_use_symlinks=False,
|
| 8635 |
-
)
|
| 8636 |
-
with open(idx_file) as fh:
|
| 8637 |
-
existing = _json.load(fh)
|
| 8638 |
-
except Exception:
|
| 8639 |
-
pass # no existing index — start fresh
|
| 8640 |
-
|
| 8641 |
-
entry: Dict[str, Any] = {
|
| 8642 |
-
"step": step,
|
| 8643 |
-
"filename": self._hf_ckpt_path(step),
|
| 8644 |
-
"size_mb": round(local_path.stat().st_size / 1_048_576, 2),
|
| 8645 |
-
"timestamp": datetime.now().isoformat(),
|
| 8646 |
-
}
|
| 8647 |
-
if metadata:
|
| 8648 |
-
entry.update(metadata)
|
| 8649 |
-
|
| 8650 |
-
# Preserve local-index schema
|
| 8651 |
-
if "checkpoints" not in existing:
|
| 8652 |
-
existing["checkpoints"] = []
|
| 8653 |
-
existing["checkpoints"] = [
|
| 8654 |
-
cp for cp in existing["checkpoints"] if cp.get("step") != step
|
| 8655 |
-
]
|
| 8656 |
-
existing["checkpoints"].append(entry)
|
| 8657 |
-
existing["checkpoints"].sort(key=lambda x: x.get("step", 0))
|
| 8658 |
-
existing["latest_step"] = step
|
| 8659 |
-
existing["last_updated"] = datetime.now().isoformat()
|
| 8660 |
-
existing["total_checkpoints"] = len(existing["checkpoints"])
|
| 8661 |
-
|
| 8662 |
-
self._hfapi.upload_file(
|
| 8663 |
-
path_or_fileobj=_json.dumps(existing, indent=2).encode(),
|
| 8664 |
-
path_in_repo=self.INDEX_FILENAME,
|
| 8665 |
-
repo_id=self.repo_id,
|
| 8666 |
-
repo_type="dataset",
|
| 8667 |
-
commit_message=f"update index step={step}",
|
| 8668 |
-
)
|
| 8669 |
-
|
| 8670 |
-
def _update_hf_latest(self, step: int) -> None:
|
| 8671 |
-
self._hfapi.upload_file(
|
| 8672 |
-
path_or_fileobj=str(step).encode(),
|
| 8673 |
-
path_in_repo=self.LATEST_SENTINEL,
|
| 8674 |
-
repo_id=self.repo_id,
|
| 8675 |
-
repo_type="dataset",
|
| 8676 |
-
commit_message=f"update latest β step={step}",
|
| 8677 |
-
)
|
| 8678 |
|
| 8679 |
# ββ Download βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 8680 |
|
|
@@ -9157,7 +9168,9 @@ class RankerCheckpointManager:
|
|
| 9157 |
tr.lambda_align = ckpt.get("lambda_align", tr.lambda_align)
|
| 9158 |
tr.rank_margin = ckpt.get("rank_margin", tr.rank_margin)
|
| 9159 |
if "loss_history" in ckpt:
|
| 9160 |
-
|
|
|
|
|
|
|
| 9161 |
logger.info(f" ✅ trainer restored | train_step={tr.train_step}")
|
| 9162 |
|
| 9163 |
# ββ Replay buffer βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -9190,16 +9203,27 @@ class RankerCheckpointManager:
|
|
| 9190 |
else:
|
| 9191 |
logger.warning(f" ⚠️ feature_engine[{asset_id}] not in bridge.asset_buffers — skipped")
|
| 9192 |
|
| 9193 |
-
# ── Runtime counters
|
| 9194 |
bridge.rank_count = ckpt.get("rank_count", bridge.rank_count)
|
| 9195 |
-
|
| 9196 |
-
|
| 9197 |
-
|
| 9198 |
-
|
| 9199 |
-
|
| 9200 |
-
|
| 9201 |
-
|
| 9202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9203 |
|
| 9204 |
# Post-load validation
|
| 9205 |
assert bridge.rank_count >= 0, "rank_count must be non-negative after restore"
|
|
@@ -9740,7 +9764,7 @@ class QuasarCheckpointManager:
|
|
| 9740 |
self,
|
| 9741 |
ranker, # QUASAR AXRVI Ranker instance
|
| 9742 |
training_steps: Optional[int] = None,
|
| 9743 |
-
device: str = 'cuda',
|
| 9744 |
strict: bool = True,
|
| 9745 |
load_optimizer: bool = True,
|
| 9746 |
load_replay_buffer: bool = True,
|
|
@@ -9876,7 +9900,9 @@ class QuasarCheckpointManager:
|
|
| 9876 |
tr.lambda_align = checkpoint.get("lambda_align", tr.lambda_align)
|
| 9877 |
tr.rank_margin = checkpoint.get("rank_margin", tr.rank_margin)
|
| 9878 |
if "loss_history" in checkpoint:
|
| 9879 |
-
|
|
|
|
|
|
|
| 9880 |
print(f" ✅ trainer scalars (train_step={tr.train_step})")
|
| 9881 |
|
| 9882 |
# ====================================================================
|
|
@@ -9959,17 +9985,24 @@ class QuasarCheckpointManager:
|
|
| 9959 |
except Exception as e:
|
| 9960 |
print(f" ⚠️ feature_engine[{asset_id}]: {str(e)[:50]}")
|
| 9961 |
|
| 9962 |
-
#
|
| 9963 |
-
|
| 9964 |
-
|
| 9965 |
-
|
| 9966 |
-
|
| 9967 |
-
|
| 9968 |
-
if
|
| 9969 |
-
|
| 9970 |
-
|
| 9971 |
-
|
| 9972 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9973 |
if 'rank_count' in checkpoint and hasattr(ranker, 'rank_count'):
|
| 9974 |
ranker.rank_count = checkpoint['rank_count']
|
| 9975 |
print(f" ✅ rank_count={ranker.rank_count}")
|
|
@@ -10189,7 +10222,7 @@ async def run_live_trading_system(
|
|
| 10189 |
enable_logging: bool = True,
|
| 10190 |
shreve_config: Optional[ShreveConfig] = None,
|
| 10191 |
checkpoint_dir: str = "./Ranker6",
|
| 10192 |
-
resume: bool =
|
| 10193 |
hf_repo_id: Optional[str] = "KarlQuant/k1rl-checkpoints", # HF Dataset repo
|
| 10194 |
) -> None:
|
| 10195 |
config = AssetRankerConfig(
|
|
@@ -10504,8 +10537,8 @@ def _parse_args():
|
|
| 10504 |
help="[S7] Gate E martingale deviation threshold (default 0.05)")
|
| 10505 |
parser.add_argument("--checkpoint-dir", default="./Ranker6",
|
| 10506 |
help="Directory for full-state checkpoints (default ./Ranker6)")
|
| 10507 |
-
parser.add_argument("--resume", action="store_true",
|
| 10508 |
-
help="
|
| 10509 |
parser.add_argument("--hf-repo", default=None,
|
| 10510 |
metavar="OWNER/REPO",
|
| 10511 |
help="Hugging Face Dataset repo for checkpoint sync "
|
|
@@ -10562,7 +10595,7 @@ if __name__ == "__main__":
|
|
| 10562 |
hub_ws_url = args.hub,
|
| 10563 |
enable_logging = not args.no_logs,
|
| 10564 |
checkpoint_dir = args.checkpoint_dir,
|
| 10565 |
-
resume = args.
|
| 10566 |
hf_repo_id = args.hf_repo or "KarlQuant/k1rl-checkpoints",
|
| 10567 |
)
|
| 10568 |
|
|
@@ -10585,7 +10618,7 @@ if __name__ == "__main__":
|
|
| 10585 |
hub_ws_url = args.hub,
|
| 10586 |
enable_logging = not args.no_logs,
|
| 10587 |
checkpoint_dir = args.checkpoint_dir, # FIX 1: was silently ignored
|
| 10588 |
-
resume = args.
|
| 10589 |
hf_repo_id = args.hf_repo or "KarlQuant/k1rl-checkpoints",
|
| 10590 |
))
|
| 10591 |
except KeyboardInterrupt:
|
|
|
|
| 3939 |
"""
|
| 3940 |
Live QCSAM/FABLE cross-asset core — integrates the full quantum pipeline
|
| 3941 |
into the AXRVINet forward pass.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3942 |
|
| 3943 |
This is the component that turns dead QCSAM/FABLE code into the active
|
| 3944 |
cross-asset interaction engine of the ranker.
|
|
|
|
| 3969 |
qffn_output : (B, hilbert_dim) cdouble
|
| 3970 |
readout_input : (B, N, hilbert_dim) float32 (real part broadcast)
|
| 3971 |
layer_output : (B, N, d_model) float32
|
| 3972 |
+
|
| 3973 |
+
Checkpoint notes:
|
| 3974 |
+
All submodules (adapter_proj, qmha, qffn, readout, residual_gate) are
|
| 3975 |
+
registered nn.Module / nn.Parameter attributes and are therefore fully
|
| 3976 |
+
captured by AXRVINet.state_dict(). FABLECLCU inside qmha is eagerly
|
| 3977 |
+
initialised via num_assets so its weights are in state_dict from the
|
| 3978 |
+
very first save. _clcu_n_patches_buf (persistent buffer) guards
|
| 3979 |
+
against accidental re-initialisation on load_state_dict.
|
| 3980 |
"""
|
| 3981 |
|
| 3982 |
# Registry flag — set True after first successful forward.
|
|
|
|
| 4174 |
# directly (e.g. test_components) still get the right architecture.
|
| 4175 |
if feature_dim is not None or seq_len is not None:
|
| 4176 |
import copy
|
| 4177 |
+
config = copy.copy(config) # FIX 8: was cssopy.copy (NameError typo)
|
| 4178 |
if feature_dim is not None:
|
| 4179 |
config.feature_dim = feature_dim
|
| 4180 |
if seq_len is not None:
|
|
|
|
| 5641 |
self.lambda_align= ckpt.get("lambda_align", self.lambda_align) # QCSAM
|
| 5642 |
self.rank_margin = ckpt.get("rank_margin", self.rank_margin)
|
| 5643 |
if "loss_history" in ckpt:
|
| 5644 |
+
# FIX 4c: replace (not extend) to avoid doubling on resume
|
| 5645 |
+
_maxlen = self.loss_history.maxlen
|
| 5646 |
+
self.loss_history = deque(ckpt["loss_history"], maxlen=_maxlen)
|
| 5647 |
logger.info(f"✅ Model loaded ← {path} | resumed from train_step={self.train_step}")
|
| 5648 |
except FileNotFoundError:
|
| 5649 |
logger.info(f"ℹ️ No checkpoint at {path} — starting fresh")
|
|
|
|
| 6683 |
hub_ws_url: str = os.environ.get("QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe"),
|
| 6684 |
enable_logging: bool = True,
|
| 6685 |
checkpoint_dir: str = "./Ranker6", # folder for full-state checkpoints
|
| 6686 |
+
resume: bool = True, # FIX 1c: default True — always resume from latest checkpoint
|
| 6687 |
hf_repo_id: Optional[str] = "KarlQuant/k1rl-checkpoints", # HF Dataset repo
|
| 6688 |
):
|
| 6689 |
self.config = config or AssetRankerConfig()
|
|
|
|
| 8248 |
f"ce={train_result.get('ce', 0):.4f} | "
|
| 8249 |
f"batch_size={len(batch)}"
|
| 8250 |
)
|
| 8251 |
+
# ── HF SYNC: mirror trainer .pt to HF periodically ────────────
|
| 8252 |
+
# FIX 6: Upload every 10 training steps (not every step) to avoid
|
| 8253 |
+
# flooding HF with commits. The full-state autosave (maybe_save)
|
| 8254 |
+
# handles persistence every 5 min regardless.
|
| 8255 |
+
if self.trainer.train_step % 10 == 0:
|
| 8256 |
+
self.checkpoint_mgr._hf.queue_upload(
|
| 8257 |
+
local_path=self.config.model_path,
|
| 8258 |
+
step=self.trainer.train_step,
|
| 8259 |
+
metadata={
|
| 8260 |
+
"reason": "train_step",
|
| 8261 |
+
"loss": train_result.get("total", 0.0),
|
| 8262 |
+
"rl": train_result.get("rl", 0.0),
|
| 8263 |
+
"ce": train_result.get("ce", 0.0),
|
| 8264 |
+
"train_step": self.trainer.train_step,
|
| 8265 |
+
},
|
| 8266 |
+
)
|
| 8267 |
else:
|
| 8268 |
logger.warning(
|
| 8269 |
"⚠️ [Training] train_on_batch returned empty result — "
|
|
|
|
| 8582 |
"""
|
| 8583 |
Upload *local_path* to HF Dataset as step_XXXXXXX.pt.
|
| 8584 |
Updates ranker_index.json and latest sentinel on HF.
|
| 8585 |
+
|
| 8586 |
+
FIX 3: All three files (.pt, index, sentinel) are batched into a
|
| 8587 |
+
SINGLE commit via create_commit + CommitOperationAdd, reducing HF
|
| 8588 |
+
commit usage from 3 → 1 per save. This prevents the 256/hr rate
|
| 8589 |
+
limit from being blown by the 11-space fleet.
|
| 8590 |
Returns True on success, False on any error (non-raising).
|
| 8591 |
"""
|
| 8592 |
if not self._ensure_hf():
|
|
|
|
| 8597 |
logger.warning(f"☁️ HF upload skipped — file not found: {local_path}")
|
| 8598 |
return False
|
| 8599 |
|
| 8600 |
+
import tempfile
|
| 8601 |
+
import json as _json
|
| 8602 |
+
|
| 8603 |
hf_path = self._hf_ckpt_path(step)
|
| 8604 |
try:
|
| 8605 |
if self.verbose:
|
| 8606 |
+
logger.info(f"☁️ Uploading {hf_path} → {self.repo_id} (batched commit) …")
|
| 8607 |
|
| 8608 |
+
from huggingface_hub import CommitOperationAdd
|
| 8609 |
+
|
| 8610 |
+
# ββ Build updated index payload ββββββββββββββββββββββββββββββββββββ
|
| 8611 |
+
existing: Dict[str, Any] = {}
|
| 8612 |
+
try:
|
| 8613 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 8614 |
+
idx_file = self._hf_hub_dl(
|
| 8615 |
+
repo_id=self.repo_id,
|
| 8616 |
+
filename=self.INDEX_FILENAME,
|
| 8617 |
+
repo_type="dataset",
|
| 8618 |
+
token=self.token,
|
| 8619 |
+
local_dir=tmpdir,
|
| 8620 |
+
local_dir_use_symlinks=False,
|
| 8621 |
+
)
|
| 8622 |
+
with open(idx_file) as fh:
|
| 8623 |
+
existing = _json.load(fh)
|
| 8624 |
+
except Exception:
|
| 8625 |
+
pass # no existing index — start fresh
|
| 8626 |
+
|
| 8627 |
+
entry: Dict[str, Any] = {
|
| 8628 |
+
"step": step,
|
| 8629 |
+
"filename": hf_path,
|
| 8630 |
+
"size_mb": round(local_path.stat().st_size / 1_048_576, 2),
|
| 8631 |
+
"timestamp": datetime.now().isoformat(),
|
| 8632 |
+
}
|
| 8633 |
+
if metadata:
|
| 8634 |
+
entry.update(metadata)
|
| 8635 |
|
| 8636 |
+
if "checkpoints" not in existing:
|
| 8637 |
+
existing["checkpoints"] = []
|
| 8638 |
+
existing["checkpoints"] = [
|
| 8639 |
+
cp for cp in existing["checkpoints"] if cp.get("step") != step
|
| 8640 |
+
]
|
| 8641 |
+
existing["checkpoints"].append(entry)
|
| 8642 |
+
existing["checkpoints"].sort(key=lambda x: x.get("step", 0))
|
| 8643 |
+
existing["latest_step"] = step
|
| 8644 |
+
existing["last_updated"] = datetime.now().isoformat()
|
| 8645 |
+
existing["total_checkpoints"] = len(existing["checkpoints"])
|
| 8646 |
+
|
| 8647 |
+
index_bytes = _json.dumps(existing, indent=2).encode()
|
| 8648 |
+
sentinel_bytes = str(step).encode()
|
| 8649 |
+
|
| 8650 |
+
# ββ Single batched commit: .pt + index + sentinel ββββββββββββββββββ
|
| 8651 |
+
with open(local_path, "rb") as pt_fh:
|
| 8652 |
+
pt_bytes = pt_fh.read()
|
| 8653 |
+
|
| 8654 |
+
self._hfapi.create_commit(
|
| 8655 |
+
repo_id=self.repo_id,
|
| 8656 |
+
repo_type="dataset",
|
| 8657 |
+
operations=[
|
| 8658 |
+
CommitOperationAdd(
|
| 8659 |
+
path_in_repo=hf_path,
|
| 8660 |
+
path_or_fileobj=pt_bytes,
|
| 8661 |
+
),
|
| 8662 |
+
CommitOperationAdd(
|
| 8663 |
+
path_in_repo=self.INDEX_FILENAME,
|
| 8664 |
+
path_or_fileobj=index_bytes,
|
| 8665 |
+
),
|
| 8666 |
+
CommitOperationAdd(
|
| 8667 |
+
path_in_repo=self.LATEST_SENTINEL,
|
| 8668 |
+
path_or_fileobj=sentinel_bytes,
|
| 8669 |
+
),
|
| 8670 |
+
],
|
| 8671 |
+
commit_message=f"checkpoint step={step:07d}",
|
| 8672 |
+
)
|
| 8673 |
|
| 8674 |
self._upload_count += 1
|
| 8675 |
self._last_upload_step = step
|
| 8676 |
+
logger.info(
|
| 8677 |
+
f"☁️ ✅ HF upload complete (1 commit) | step={step} | repo={self.repo_id}"
|
| 8678 |
+
)
|
| 8679 |
return True
|
| 8680 |
|
| 8681 |
except Exception as exc:
|
| 8682 |
logger.warning(f"☁️ ⚠️ HF upload failed (non-fatal): {exc}")
|
| 8683 |
return False
|
| 8684 |
|
| 8685 |
+
# _update_hf_index and _update_hf_latest removed in FIX 3.
|
| 8686 |
+
# All three files (.pt, index, sentinel) are now batched into a single
|
| 8687 |
+
# create_commit call inside upload() to avoid the 3x commit-per-save
|
| 8688 |
+
# rate-limit explosion.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8689 |
|
| 8690 |
# ββ Download βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 8691 |
|
|
|
|
| 9168 |
tr.lambda_align = ckpt.get("lambda_align", tr.lambda_align)
|
| 9169 |
tr.rank_margin = ckpt.get("rank_margin", tr.rank_margin)
|
| 9170 |
if "loss_history" in ckpt:
|
| 9171 |
+
# FIX 4: replace (not extend) so history isn't doubled on restore
|
| 9172 |
+
_maxlen = tr.loss_history.maxlen
|
| 9173 |
+
tr.loss_history = deque(ckpt["loss_history"], maxlen=_maxlen)
|
| 9174 |
logger.info(f" ✅ trainer restored | train_step={tr.train_step}")
|
| 9175 |
|
| 9176 |
# ββ Replay buffer βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 9203 |
else:
|
| 9204 |
logger.warning(f" ⚠️ feature_engine[{asset_id}] not in bridge.asset_buffers — skipped")
|
| 9205 |
|
| 9206 |
+
# ββ Runtime counters (safe to restore) βββββββββββββββββββββββββββββββ
|
| 9207 |
bridge.rank_count = ckpt.get("rank_count", bridge.rank_count)
|
| 9208 |
+
|
| 9209 |
+
# FIX 5: Do NOT restore pending_episodes or trade_tick_counts.
|
| 9210 |
+
# These reference Deriv multiplier contracts that are gone after a restart.
|
| 9211 |
+
# Restoring them causes phantom open_count > 0, which blocks the top-4
|
| 9212 |
+
# floor enforcer from opening any new trades and corrupts RL episode pairing.
|
| 9213 |
+
# rank_count is the only runtime counter safe to restore across restarts.
|
| 9214 |
+
stale_ep = ckpt.get("pending_episodes", {})
|
| 9215 |
+
stale_tc = ckpt.get("trade_tick_counts", {})
|
| 9216 |
+
if stale_ep:
|
| 9217 |
+
logger.warning(
|
| 9218 |
+
f" ⚠️ Discarded {len(stale_ep)} stale pending_episodes from checkpoint "
|
| 9219 |
+
"(Deriv contracts expired across restart — resetting to avoid phantom positions)"
|
| 9220 |
+
)
|
| 9221 |
+
if stale_tc:
|
| 9222 |
+
logger.info(
|
| 9223 |
+
f" ℹ️ Discarded {len(stale_tc)} stale trade_tick_counts (reset on restart)"
|
| 9224 |
+
)
|
| 9225 |
+
# bridge._pending_episodes and bridge._trade_tick_counts stay at their
|
| 9226 |
+
# freshly-initialised empty-dict state — no update needed.
|
| 9227 |
|
| 9228 |
# Post-load validation
|
| 9229 |
assert bridge.rank_count >= 0, "rank_count must be non-negative after restore"
|
|
|
|
| 9764 |
self,
|
| 9765 |
ranker, # QUASAR AXRVI Ranker instance
|
| 9766 |
training_steps: Optional[int] = None,
|
| 9767 |
+
device: str = 'cpu', # FIX 7: was 'cuda' — HF Spaces are CPU-only, 'cuda' crashes
|
| 9768 |
strict: bool = True,
|
| 9769 |
load_optimizer: bool = True,
|
| 9770 |
load_replay_buffer: bool = True,
|
|
|
|
| 9900 |
tr.lambda_align = checkpoint.get("lambda_align", tr.lambda_align)
|
| 9901 |
tr.rank_margin = checkpoint.get("rank_margin", tr.rank_margin)
|
| 9902 |
if "loss_history" in checkpoint:
|
| 9903 |
+
# FIX 4b: replace (not extend) to avoid doubling history on restore
|
| 9904 |
+
_maxlen = tr.loss_history.maxlen
|
| 9905 |
+
tr.loss_history = deque(checkpoint["loss_history"], maxlen=_maxlen)
|
| 9906 |
print(f" ✅ trainer scalars (train_step={tr.train_step})")
|
| 9907 |
|
| 9908 |
# ====================================================================
|
|
|
|
| 9985 |
except Exception as e:
|
| 9986 |
print(f" ⚠️ feature_engine[{asset_id}]: {str(e)[:50]}")
|
| 9987 |
|
| 9988 |
+
# FIX 5b: Do NOT restore pending_episodes or trade_tick_counts.
|
| 9989 |
+
# These reference Deriv multiplier contracts that are gone after a restart.
|
| 9990 |
+
# Restoring them causes phantom open_count > 0, blocking the top-4 floor
|
| 9991 |
+
# enforcer and corrupting RL episode pairing with dead contract IDs.
|
| 9992 |
+
stale_ep = checkpoint.get('pending_episodes', {})
|
| 9993 |
+
stale_tc = checkpoint.get('trade_tick_counts', {})
|
| 9994 |
+
if stale_ep:
|
| 9995 |
+
print(
|
| 9996 |
+
f" ⚠️ Discarded {len(stale_ep)} stale pending_episodes "
|
| 9997 |
+
"(Deriv contracts expired across restart — phantom positions prevented)"
|
| 9998 |
+
)
|
| 9999 |
+
if stale_tc:
|
| 10000 |
+
print(
|
| 10001 |
+
f" ℹ️ Discarded {len(stale_tc)} stale trade_tick_counts (reset on restart)"
|
| 10002 |
+
)
|
| 10003 |
+
# _pending_episodes and _trade_tick_counts stay at fresh empty-dict state.
|
| 10004 |
+
|
| 10005 |
+
# Runtime counters (safe to restore)
|
| 10006 |
if 'rank_count' in checkpoint and hasattr(ranker, 'rank_count'):
|
| 10007 |
ranker.rank_count = checkpoint['rank_count']
|
| 10008 |
print(f" ✅ rank_count={ranker.rank_count}")
|
|
|
|
| 10222 |
enable_logging: bool = True,
|
| 10223 |
shreve_config: Optional[ShreveConfig] = None,
|
| 10224 |
checkpoint_dir: str = "./Ranker6",
|
| 10225 |
+
resume: bool = True, # FIX 2: default True — always resume from checkpoint
|
| 10226 |
hf_repo_id: Optional[str] = "KarlQuant/k1rl-checkpoints", # HF Dataset repo
|
| 10227 |
) -> None:
|
| 10228 |
config = AssetRankerConfig(
|
|
|
|
| 10537 |
help="[S7] Gate E martingale deviation threshold (default 0.05)")
|
| 10538 |
parser.add_argument("--checkpoint-dir", default="./Ranker6",
|
| 10539 |
help="Directory for full-state checkpoints (default ./Ranker6)")
|
| 10540 |
+
parser.add_argument("--no-resume", dest="no_resume", action="store_true",
|
| 10541 |
+
help="Disable checkpoint resume on startup (default: always resume from latest checkpoint)")
|
| 10542 |
parser.add_argument("--hf-repo", default=None,
|
| 10543 |
metavar="OWNER/REPO",
|
| 10544 |
help="Hugging Face Dataset repo for checkpoint sync "
|
|
|
|
| 10595 |
hub_ws_url = args.hub,
|
| 10596 |
enable_logging = not args.no_logs,
|
| 10597 |
checkpoint_dir = args.checkpoint_dir,
|
| 10598 |
+
resume = not args.no_resume, # FIX 1: default True (always resume)
|
| 10599 |
hf_repo_id = args.hf_repo or "KarlQuant/k1rl-checkpoints",
|
| 10600 |
)
|
| 10601 |
|
|
|
|
| 10618 |
hub_ws_url = args.hub,
|
| 10619 |
enable_logging = not args.no_logs,
|
| 10620 |
checkpoint_dir = args.checkpoint_dir, # FIX 1: was silently ignored
|
| 10621 |
+
resume = not args.no_resume, # FIX 2: default True (always resume)
|
| 10622 |
hf_repo_id = args.hf_repo or "KarlQuant/k1rl-checkpoints",
|
| 10623 |
))
|
| 10624 |
except KeyboardInterrupt:
|