| |
| import os, json |
| from contextlib import contextmanager |
|
|
| def _parse_bool(v: str, default=False): |
| if v is None: return default |
| v = v.strip().lower() |
| return v in {"1","true","yes","y","t","on"} |
|
|
| def _parse_float(v: str, default=None): |
| try: return float(v) if v is not None else default |
| except: return default |
|
|
| def _parse_int(v: str, default=None): |
| try: return int(v) if v is not None else default |
| except: return default |
|
|
def get_env_aop_config():
    """Assemble the AOP pruning configuration from ``AOP_*`` environment variables.

    Every variable is optional; the fallback for each one appears next to it
    below. Numeric and boolean values are parsed with the ``_parse_int`` /
    ``_parse_float`` / ``_parse_bool`` helpers. Free-form string options are
    stripped and lower-cased.

    Returns:
        dict: the configuration that ``apply_aop_to_model`` stores on
        ``model.encoder.aop_prune_config``. ``"enabled"`` is forced to
        ``False`` whenever ``AOP_LAYER`` is missing or not an integer.
    """
    env = os.environ.get

    def option(name, fallback):
        # String-valued option, normalised to stripped lower-case.
        return env(name, fallback).strip().lower()

    layer_idx = _parse_int(env("AOP_LAYER"), None)
    # Without a target layer there is nothing to prune, so AOP_ENABLED is
    # overridden to off.
    enabled = _parse_bool(env("AOP_ENABLED"), False) and layer_idx is not None

    # AOP_RANDOM=1 is a shortcut that overrides AOP_SELECTION.
    if _parse_bool(env("AOP_RANDOM"), False):
        selection = "random"
    else:
        selection = option("AOP_SELECTION", "aop")

    margin_src = option("AOP_MARGIN", "")
    attn_impl = option("AOP_ATTN_IMPL", "")

    cfg = {
        "enabled": enabled,
        "apply_to": option("AOP_APPLY", "qry"),
        "layer_idx": layer_idx,
        "mode": option("AOP_MODE", "delta"),
        # Global selection hyper-parameters.
        "delta": _parse_float(env("AOP_DELTA"), 0.10),
        "K_hat": _parse_float(env("AOP_KHAT"), 1.0),
        "keep_ratio": _parse_float(env("AOP_KEEP_RATIO"), 1.0),
        "min_keep": _parse_int(env("AOP_MIN_KEEP"), 64),
        "use_bias": _parse_bool(env("AOP_USE_BIAS"), True),
        "eps": 1e-6,
        # Which modalities are eligible for pruning.
        "prune_vision": _parse_bool(env("AOP_PRUNE_VISION"), True),
        "prune_text": _parse_bool(env("AOP_PRUNE_TEXT"), False),
        # Per-modality overrides. These are None when unset; presumably the
        # consumer then uses the global value (confirm). Note min_keep_text
        # defaults to 32 rather than None.
        "delta_vision": _parse_float(env("AOP_DELTA_VISION"), None),
        "K_hat_vision": _parse_float(env("AOP_KHAT_VISION"), None),
        "keep_ratio_vision": _parse_float(env("AOP_KEEP_RATIO_VISION"), None),
        "min_keep_vision": _parse_int(env("AOP_MIN_KEEP_VISION"), None),
        "delta_text": _parse_float(env("AOP_DELTA_TEXT"), None),
        "K_hat_text": _parse_float(env("AOP_KHAT_TEXT"), None),
        "keep_ratio_text": _parse_float(env("AOP_KEEP_RATIO_TEXT"), None),
        "min_keep_text": _parse_int(env("AOP_MIN_KEEP_TEXT"), 32),
        "protect_text_last": _parse_int(env("AOP_PROTECT_TEXT_LAST"), 16),
        "protect_special": _parse_bool(env("AOP_PROTECT_SPECIAL"), True),
        # AOP_MARGIN=mid is signalled downstream with a sentinel string.
        "margin_mid": "USE_MID_MARGIN" if margin_src == "mid" else None,
        "epsilon_hat": None,
        # Only "sdpa" is accepted as an attention-implementation override.
        "attn_impl_override": attn_impl if attn_impl == "sdpa" else "",
        "selection": selection,
        "random_seed": _parse_int(env("AOP_RANDOM_SEED"), None),
        "attn_agg": option("AOP_ATTENTION_AGG", "mean"),
    }
    return cfg
|
|
def apply_aop_to_model(model):
    """Load the AOP configuration from the environment and attach it to ``model``.

    When AOP is enabled:

    - the config dict is stored as ``model.encoder.aop_prune_config``;
    - if ``AOP_ATTN_IMPL=sdpa`` was given,
      ``model.encoder.model.config._attn_implementation`` is switched to it.
      Attention-based importance needs this because flash_attn2 does not
      output attention weights. The switch is best effort: failures are
      printed, not raised.

    A short summary of the config is printed in either case.

    Args:
        model: object exposing an ``encoder`` attribute.

    Returns:
        dict: the configuration from ``get_env_aop_config``. When AOP is
        disabled it is returned without being attached to the model.
    """
    cfg = get_env_aop_config()
    if not cfg["enabled"]:
        print("[AOP] disabled")
        return cfg

    model.encoder.aop_prune_config = cfg

    override = cfg.get("attn_impl_override", "")
    if override:
        try:
            encoder = model.encoder
            reachable = hasattr(encoder, "model") and hasattr(encoder.model, "config")
            if reachable:
                inner_config = encoder.model.config
                previous = inner_config._attn_implementation
                inner_config._attn_implementation = override
                print(f"[AOP] override attn impl: {previous} -> {override}")
        except Exception as err:
            # Best effort: report the failure and keep the current attn impl.
            print(f"[AOP] try override attn impl failed: {err}")

    summary = {
        "apply_to": cfg["apply_to"],
        "layer_idx": cfg["layer_idx"],
        "mode": cfg["mode"],
        "prune_text": cfg.get("prune_text", False),
        "keep_ratio_text": cfg.get("keep_ratio_text", None),
        "keep_ratio_vision": cfg.get("keep_ratio_vision", None),
        "selection": cfg.get("selection", "aop"),
        "attn_agg": cfg.get("attn_agg", "mean"),
    }
    print("[AOP] config:", json.dumps(summary))
    return cfg
|
|
@contextmanager
def aop_side(model, side: str):
    """Temporarily restrict AOP to one side of the model.

    Usage::

        with aop_side(model, "qry"):
            ...

    Inside the block, ``aop_prune_config["enabled"]`` stays truthy only if it
    was already enabled and ``apply_to`` is ``"both"`` or equals ``side``.
    The original ``enabled`` value is restored on exit, even if the body
    raises. The config dict is modified in place. If ``model.encoder`` has no
    non-empty dict config, nothing is changed.

    Args:
        model: object exposing an ``encoder`` attribute.
        side: side name compared against the config's ``apply_to``
            (e.g. ``"qry"``).
    """
    cfg = getattr(model.encoder, "aop_prune_config", None)
    saved = None
    if isinstance(cfg, dict) and cfg:
        saved = cfg.get("enabled", False)
        targets = cfg.get("apply_to", "qry")
        cfg["enabled"] = bool(saved and targets in ("both", side))
        setattr(model.encoder, "aop_prune_config", cfg)
    try:
        yield
    finally:
        # saved is only non-None when cfg was a non-empty dict above.
        if saved is not None:
            cfg["enabled"] = saved
            setattr(model.encoder, "aop_prune_config", cfg)