# src/utils/aop_utils.py
import os, json
from contextlib import contextmanager
def _parse_bool(v: str, default=False):
if v is None: return default
v = v.strip().lower()
return v in {"1","true","yes","y","t","on"}
def _parse_float(v: str, default=None):
try: return float(v) if v is not None else default
except: return default
def _parse_int(v: str, default=None):
try: return int(v) if v is not None else default
except: return default
def get_env_aop_config():
    """Build the AOP pruning configuration dict from AOP_* env variables.

    Every knob falls back to a default when its variable is unset;
    the per-modality "*_vision"/"*_text" values default to None, meaning
    "use the shared value". "enabled" is forced off when no AOP_LAYER
    index is provided, since pruning needs a target layer.
    """
    env = os.environ.get

    enabled = _parse_bool(env("AOP_ENABLED"), False)
    layer_idx = _parse_int(env("AOP_LAYER"), None)
    # A layer index is mandatory: without one, pruning cannot be applied.
    if layer_idx is None and enabled:
        enabled = False

    # Token-selection strategy: aop | attention | random.
    # AOP_RANDOM=1 is a shorthand that overrides AOP_SELECTION.
    selection = env("AOP_SELECTION", "aop").strip().lower()
    if _parse_bool(env("AOP_RANDOM"), False):
        selection = "random"

    margin_src = env("AOP_MARGIN", "").strip().lower()
    attn_impl = env("AOP_ATTN_IMPL", "").strip().lower()  # "" | "sdpa"

    return {
        "enabled": enabled,
        "apply_to": env("AOP_APPLY", "qry").strip().lower(),  # qry|cand|both
        "layer_idx": layer_idx,
        "mode": env("AOP_MODE", "delta").strip().lower(),
        "delta": _parse_float(env("AOP_DELTA"), 0.10),
        "K_hat": _parse_float(env("AOP_KHAT"), 1.0),
        "keep_ratio": _parse_float(env("AOP_KEEP_RATIO"), 1.0),
        "min_keep": _parse_int(env("AOP_MIN_KEEP"), 64),
        "use_bias": _parse_bool(env("AOP_USE_BIAS"), True),
        "eps": 1e-6,
        "prune_vision": _parse_bool(env("AOP_PRUNE_VISION"), True),
        "prune_text": _parse_bool(env("AOP_PRUNE_TEXT"), False),
        # Vision-side overrides (None -> inherit the shared value above).
        "delta_vision": _parse_float(env("AOP_DELTA_VISION"), None),
        "K_hat_vision": _parse_float(env("AOP_KHAT_VISION"), None),
        "keep_ratio_vision": _parse_float(env("AOP_KEEP_RATIO_VISION"), None),
        "min_keep_vision": _parse_int(env("AOP_MIN_KEEP_VISION"), None),
        # Text-side overrides (min_keep_text has its own default of 32).
        "delta_text": _parse_float(env("AOP_DELTA_TEXT"), None),
        "K_hat_text": _parse_float(env("AOP_KHAT_TEXT"), None),
        "keep_ratio_text": _parse_float(env("AOP_KEEP_RATIO_TEXT"), None),
        "min_keep_text": _parse_int(env("AOP_MIN_KEEP_TEXT"), 32),
        "protect_text_last": _parse_int(env("AOP_PROTECT_TEXT_LAST"), 16),
        "protect_special": _parse_bool(env("AOP_PROTECT_SPECIAL"), True),
        # Sentinel consumed downstream; only set when AOP_MARGIN=mid.
        "margin_mid": "USE_MID_MARGIN" if margin_src == "mid" else None,
        "epsilon_hat": None,
        # Only "sdpa" is an accepted override; anything else means no override.
        "attn_impl_override": attn_impl if attn_impl == "sdpa" else "",
        "selection": selection,
        "random_seed": _parse_int(env("AOP_RANDOM_SEED"), None),
        "attn_agg": env("AOP_ATTENTION_AGG", "mean").strip().lower(),  # mean|max|sum
    }
def apply_aop_to_model(model):
    """Attach the env-derived AOP pruning config to ``model.encoder``.

    Sets ``model.encoder.aop_prune_config`` and, when AOP_ATTN_IMPL
    requests it, best-effort switches the underlying HF model's attention
    implementation (flash_attn2 does not expose attention weights, so
    "sdpa" is required for attention-based importance scoring).

    Returns the config dict whether or not AOP is enabled.
    """
    cfg = get_env_aop_config()
    if not cfg["enabled"]:
        print("[AOP] disabled")
        return cfg

    setattr(model.encoder, "aop_prune_config", cfg)

    override = cfg.get("attn_impl_override", "")
    if override:
        # Best-effort only: model layout varies, so never let this fail hard.
        try:
            inner = getattr(model.encoder, "model", None)
            hf_config = getattr(inner, "config", None)
            if hf_config is not None:
                prev = hf_config._attn_implementation
                hf_config._attn_implementation = override
                print(f"[AOP] override attn impl: {prev} -> {override}")
        except Exception as e:
            print(f"[AOP] try override attn impl failed: {e}")

    # Log a compact summary of the knobs that matter most for runs.
    summary = {
        "apply_to": cfg["apply_to"],
        "layer_idx": cfg["layer_idx"],
        "mode": cfg["mode"],
        "prune_text": cfg.get("prune_text", False),
        "keep_ratio_text": cfg.get("keep_ratio_text", None),
        "keep_ratio_vision": cfg.get("keep_ratio_vision", None),
        "selection": cfg.get("selection", "aop"),
        "attn_agg": cfg.get("attn_agg", "mean"),
    }
    print("[AOP] config:", json.dumps(summary))
    return cfg
@contextmanager
def aop_side(model, side: str):
    """Scope AOP to one retrieval side.

    Usage: ``with aop_side(model, "qry"): ...`` — inside the block AOP is
    enabled only if it was already enabled AND AOP_APPLY covers *side*
    (exact match or "both"); the previous enabled flag is restored on exit.
    A model without a dict config is passed through untouched.
    """
    cfg = getattr(model.encoder, "aop_prune_config", None)
    saved = None
    if isinstance(cfg, dict) and cfg:
        saved = cfg.get("enabled", False)
        target = cfg.get("apply_to", "qry")
        cfg["enabled"] = bool(saved and target in (side, "both"))
        setattr(model.encoder, "aop_prune_config", cfg)
    try:
        yield
    finally:
        # Restore only when we actually captured a previous flag.
        if isinstance(cfg, dict) and saved is not None:
            cfg["enabled"] = saved
            setattr(model.encoder, "aop_prune_config", cfg)