import os
import json
from contextlib import contextmanager


def _parse_bool(v: str, default=False):
    """Parse a truthy env string ("1", "true", "yes", "y", "t", "on"); fall back to `default` if unset."""
    if v is None:
        return default
    return v.strip().lower() in {"1", "true", "yes", "y", "t", "on"}


def _parse_float(v: str, default=None):
    """Parse a float env string; fall back to `default` if unset or malformed."""
    if v is None:
        return default
    try:
        return float(v)
    except (TypeError, ValueError):
        return default


def _parse_int(v: str, default=None):
    """Parse an int env string; fall back to `default` if unset or malformed."""
    if v is None:
        return default
    try:
        return int(v)
    except (TypeError, ValueError):
        return default
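
# The helpers above tolerate unset and malformed environment values by falling back
# to the supplied default. Illustrative calls (values are examples, not read from a
# real environment):
#
#   _parse_bool(os.environ.get("AOP_ENABLED"), False)  # -> False when AOP_ENABLED is unset
#   _parse_bool("Yes", False)                          # -> True
#   _parse_float("0.25", None)                         # -> 0.25
#   _parse_int("not-a-number", 64)                     # -> 64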


def get_env_aop_config():
    """Build the AOP pruning config from environment variables (all optional)."""
    # Global switches.
    enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
    apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower()  # which side(s) to prune: "both" or a single side such as "qry"
    layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
    mode = os.environ.get("AOP_MODE", "delta").strip().lower()

    # Default pruning hyperparameters.
    delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
    khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
    keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
    min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
    use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)

    # Which token streams to prune.
    prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
    prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)

    # Optional vision-specific overrides.
    delta_v = _parse_float(os.environ.get("AOP_DELTA_VISION"), None)
    khat_v = _parse_float(os.environ.get("AOP_KHAT_VISION"), None)
    keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
    min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None)

    # Optional text-specific overrides.
    delta_t = _parse_float(os.environ.get("AOP_DELTA_TEXT"), None)
    khat_t = _parse_float(os.environ.get("AOP_KHAT_TEXT"), None)
    keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
    min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32)

    # Tokens protected from pruning.
    protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16)
    protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True)

    margin_src = os.environ.get("AOP_MARGIN", "").strip().lower()
    attn_impl = os.environ.get("AOP_ATTN_IMPL", "").strip().lower()  # only "sdpa" is honored (see attn_impl_override below)

    # Token-selection strategy; AOP_RANDOM=1 forces the random baseline.
    selection = os.environ.get("AOP_SELECTION", "aop").strip().lower()
    if _parse_bool(os.environ.get("AOP_RANDOM"), False):
        selection = "random"
    random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None)
    attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower()

    # A target layer is required; without AOP_LAYER the feature stays off.
    if layer_idx is None and enabled:
        enabled = False

    return {
        "enabled": enabled,
        "apply_to": apply_to,
        "layer_idx": layer_idx,
        "mode": mode,

        "delta": delta, "K_hat": khat,
        "keep_ratio": keep_ratio, "min_keep": min_keep,
        "use_bias": use_bias, "eps": 1e-6,

        "prune_vision": prune_vision,
        "prune_text": prune_text,

        "delta_vision": delta_v,
        "K_hat_vision": khat_v,
        "keep_ratio_vision": keep_ratio_v,
        "min_keep_vision": min_keep_v,

        "delta_text": delta_t,
        "K_hat_text": khat_t,
        "keep_ratio_text": keep_ratio_t,
        "min_keep_text": min_keep_t,

        "protect_text_last": protect_text_last,
        "protect_special": protect_special,

        "margin_mid": "USE_MID_MARGIN" if margin_src == "mid" else None,
        "epsilon_hat": None,
        "attn_impl_override": attn_impl if attn_impl in {"sdpa"} else "",

        "selection": selection,
        "random_seed": random_seed,
        "attn_agg": attn_agg,
    }
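
# Example of the env -> config mapping (a minimal sketch; the values are illustrative,
# not defaults of any particular deployment):
#
#   AOP_ENABLED=1 AOP_LAYER=2 AOP_KEEP_RATIO=0.5 python <your entrypoint>
#
#   cfg = get_env_aop_config()
#   cfg["enabled"]     # -> True
#   cfg["layer_idx"]   # -> 2
#   cfg["keep_ratio"]  # -> 0.5
#   cfg["min_keep"]    # -> 64  (module default)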


def apply_aop_to_model(model):
    """
    Inject the AOP config into the underlying model as model.encoder.aop_prune_config.
    If attention-based importance is needed, AOP_ATTN_IMPL=sdpa is supported
    (flash_attn2 does not return attention weights).
    """
    aop_cfg = get_env_aop_config()
    if not aop_cfg["enabled"]:
        print("[AOP] disabled")
        return aop_cfg

    setattr(model.encoder, "aop_prune_config", aop_cfg)

    # Optionally switch the attention implementation so attention weights are available.
    attn_override = aop_cfg.get("attn_impl_override", "")
    if attn_override:
        try:
            if hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"):
                prev = model.encoder.model.config._attn_implementation
                model.encoder.model.config._attn_implementation = attn_override
                print(f"[AOP] override attn impl: {prev} -> {attn_override}")
        except Exception as e:
            print(f"[AOP] attn impl override failed: {e}")

    print("[AOP] config:", json.dumps({
        "apply_to": aop_cfg["apply_to"], "layer_idx": aop_cfg["layer_idx"], "mode": aop_cfg["mode"],
        "prune_text": aop_cfg.get("prune_text", False),
        "keep_ratio_text": aop_cfg.get("keep_ratio_text", None),
        "keep_ratio_vision": aop_cfg.get("keep_ratio_vision", None),
        "selection": aop_cfg.get("selection", "aop"),
        "attn_agg": aop_cfg.get("attn_agg", "mean"),
    }))
    return aop_cfg
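
# Usage sketch (assumes `model` exposes an `.encoder` attribute and, for the attention
# override, an inner HuggingFace-style `model.encoder.model.config`; both are assumptions
# about the surrounding training code, not guarantees of this module):
#
#   os.environ["AOP_ENABLED"] = "1"
#   os.environ["AOP_LAYER"] = "2"
#   os.environ["AOP_ATTN_IMPL"] = "sdpa"  # attention importance needs attn weights; flash_attn2 returns none
#   cfg = apply_aop_to_model(model)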


@contextmanager
def aop_side(model, side: str):
    """
    with aop_side(model, "qry"): enable AOP only on the qry side (if AOP_APPLY
    includes that side); the previous state is restored automatically on exit.
    """
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    prev_enabled = None
    if isinstance(aop_cfg, dict) and aop_cfg:
        prev_enabled = aop_cfg.get("enabled", False)
        apply_to = aop_cfg.get("apply_to", "qry")
        # Enable only if the config targets this side (or both) and was enabled to begin with.
        side_enable = (apply_to == "both") or (apply_to == side)
        aop_cfg["enabled"] = bool(side_enable and prev_enabled)
        setattr(model.encoder, "aop_prune_config", aop_cfg)
    try:
        yield
    finally:
        # Restore the original "enabled" flag.
        if isinstance(aop_cfg, dict) and prev_enabled is not None:
            aop_cfg["enabled"] = prev_enabled
            setattr(model.encoder, "aop_prune_config", aop_cfg)
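

# A minimal, self-contained smoke test / usage sketch. SimpleNamespace stands in for
# the real encoder wrapper (any object with an `.encoder` attribute works); the env
# values below are illustrative only.
if __name__ == "__main__":
    from types import SimpleNamespace

    os.environ["AOP_ENABLED"] = "1"
    os.environ["AOP_LAYER"] = "2"
    os.environ["AOP_APPLY"] = "qry"

    dummy = SimpleNamespace(encoder=SimpleNamespace())
    cfg = apply_aop_to_model(dummy)
    print("enabled after apply:", dummy.encoder.aop_prune_config["enabled"])  # True

    # AOP_APPLY is "qry", so pruning is temporarily switched off on the other side ...
    with aop_side(dummy, "doc"):
        print("inside doc side:", dummy.encoder.aop_prune_config["enabled"])  # False
    # ... and restored on exit.
    print("after context:", dummy.encoder.aop_prune_config["enabled"])        # True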