# src/utils/aop_utils.py import os, json from contextlib import contextmanager def _parse_bool(v: str, default=False): if v is None: return default v = v.strip().lower() return v in {"1","true","yes","y","t","on"} def _parse_float(v: str, default=None): try: return float(v) if v is not None else default except: return default def _parse_int(v: str, default=None): try: return int(v) if v is not None else default except: return default def get_env_aop_config(): enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False) apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower() # qry|cand|both layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None) mode = os.environ.get("AOP_MODE", "delta").strip().lower() delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10) khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0) keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0) min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64) use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True) prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True) prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False) delta_v = _parse_float(os.environ.get("AOP_DELTA_VISION"), None) khat_v = _parse_float(os.environ.get("AOP_KHAT_VISION"), None) keep_ratio_v= _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None) min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None) delta_t = _parse_float(os.environ.get("AOP_DELTA_TEXT"), None) khat_t = _parse_float(os.environ.get("AOP_KHAT_TEXT"), None) keep_ratio_t= _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None) min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32) protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16) protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True) margin_src = os.environ.get("AOP_MARGIN", "").strip().lower() attn_impl = os.environ.get("AOP_ATTN_IMPL", "").strip().lower() # "" | "sdpa" selection = os.environ.get("AOP_SELECTION", "aop").strip().lower() # aop|attention|random if _parse_bool(os.environ.get("AOP_RANDOM"), False): selection = "random" random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None) attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower() # mean|max|sum if layer_idx is None and enabled: enabled = False return { "enabled": enabled, "apply_to": apply_to, "layer_idx": layer_idx, "mode": mode, "delta": delta, "K_hat": khat, "keep_ratio": keep_ratio, "min_keep": min_keep, "use_bias": use_bias, "eps": 1e-6, "prune_vision": prune_vision, "prune_text": prune_text, "delta_vision": delta_v, "K_hat_vision": khat_v, "keep_ratio_vision": keep_ratio_v, "min_keep_vision": min_keep_v, "delta_text": delta_t, "K_hat_text": khat_t, "keep_ratio_text": keep_ratio_t, "min_keep_text": min_keep_t, "protect_text_last": protect_text_last, "protect_special": protect_special, "margin_mid": None if margin_src != "mid" else "USE_MID_MARGIN", "epsilon_hat": None, "attn_impl_override": attn_impl if attn_impl in {"sdpa"} else "", "selection": selection, "random_seed": random_seed, "attn_agg": attn_agg, } def apply_aop_to_model(model): """ 注入 AOP 配置到底模:model.encoder.aop_prune_config 如需 attention 重要性,支持 AOP_ATTN_IMPL=sdpa(因为 flash_attn2 不输出 attn) """ aop_cfg = get_env_aop_config() if not aop_cfg["enabled"]: print("[AOP] disabled") return aop_cfg setattr(model.encoder, "aop_prune_config", aop_cfg) attn_override = aop_cfg.get("attn_impl_override", "") if attn_override: try: if hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"): prev = model.encoder.model.config._attn_implementation model.encoder.model.config._attn_implementation = attn_override print(f"[AOP] override attn impl: {prev} -> {attn_override}") except Exception as e: print(f"[AOP] try override attn impl failed: {e}") print("[AOP] config:", json.dumps({ "apply_to": aop_cfg["apply_to"], "layer_idx": aop_cfg["layer_idx"], "mode": aop_cfg["mode"], "prune_text": aop_cfg.get("prune_text", False), "keep_ratio_text": aop_cfg.get("keep_ratio_text", None), "keep_ratio_vision": aop_cfg.get("keep_ratio_vision", None), "selection": aop_cfg.get("selection", "aop"), "attn_agg": aop_cfg.get("attn_agg", "mean"), })) return aop_cfg @contextmanager def aop_side(model, side: str): """ with aop_side(model, "qry"): 仅在 qry 侧启用(若 AOP_APPLY 包含该侧),退出自动恢复 """ aop_cfg = getattr(model.encoder, "aop_prune_config", None) prev_enabled = None if isinstance(aop_cfg, dict) and aop_cfg: prev_enabled = aop_cfg.get("enabled", False) apply_to = aop_cfg.get("apply_to", "qry") side_enable = (apply_to == "both") or (apply_to == side) aop_cfg["enabled"] = bool(side_enable and prev_enabled) setattr(model.encoder, "aop_prune_config", aop_cfg) try: yield finally: if isinstance(aop_cfg, dict) and prev_enabled is not None: aop_cfg["enabled"] = prev_enabled setattr(model.encoder, "aop_prune_config", aop_cfg)