code_SAS_VLM2Vec / src /aop_utils.py
MgGladys's picture
Add files using upload-large-folder tool
0a937d7 verified
# src/utils/aop_utils.py
import os, json
from contextlib import contextmanager
def _parse_bool(v: str, default=False):
if v is None: return default
v = v.strip().lower()
return v in {"1","true","yes","y","t","on"}
def _parse_float(v: str, default=None):
try: return float(v) if v is not None else default
except: return default
def _parse_int(v: str, default=None):
try: return int(v) if v is not None else default
except: return default
def get_env_aop_config():
    """Assemble the AOP pruning configuration from ``AOP_*`` environment variables.

    String settings are stripped and lower-cased. Numeric settings that are
    unset or cannot be parsed fall back to the defaults written below.
    Boolean settings use their default only when the variable is unset.

    Notable rules:
        * ``AOP_RANDOM`` (truthy) forces ``selection`` to ``"random"``,
          overriding ``AOP_SELECTION``.
        * ``enabled`` is reported as ``False`` whenever ``AOP_LAYER`` does not
          yield an integer, even if ``AOP_ENABLED`` is set.
        * ``margin_mid`` is the marker string ``"USE_MID_MARGIN"`` only when
          ``AOP_MARGIN`` equals ``"mid"``; otherwise ``None``.
        * ``attn_impl_override`` is kept only when it is ``"sdpa"``; any other
          value becomes the empty string.
        * ``eps`` (``1e-6``) and ``epsilon_hat`` (``None``) are fixed values,
          not read from the environment.

    Returns:
        dict: The configuration dictionary.
    """
    env = os.environ.get

    def _lowered(name, fallback):
        # Normalise a string-valued setting: trim whitespace, lower-case.
        return env(name, fallback).strip().lower()

    selection = _lowered("AOP_SELECTION", "aop")  # aop|attention|random
    if _parse_bool(env("AOP_RANDOM"), False):
        selection = "random"

    layer_idx = _parse_int(env("AOP_LAYER"), None)
    # Pruning needs a target layer; without one the feature is switched off.
    enabled = _parse_bool(env("AOP_ENABLED"), False) and layer_idx is not None

    margin_src = _lowered("AOP_MARGIN", "")
    attn_impl = _lowered("AOP_ATTN_IMPL", "")  # "" | "sdpa"

    config = {
        "enabled": enabled,
        "apply_to": _lowered("AOP_APPLY", "qry"),  # qry|cand|both
        "layer_idx": layer_idx,
        "mode": _lowered("AOP_MODE", "delta"),
        # Global pruning parameters.
        "delta": _parse_float(env("AOP_DELTA"), 0.10),
        "K_hat": _parse_float(env("AOP_KHAT"), 1.0),
        "keep_ratio": _parse_float(env("AOP_KEEP_RATIO"), 1.0),
        "min_keep": _parse_int(env("AOP_MIN_KEEP"), 64),
        "use_bias": _parse_bool(env("AOP_USE_BIAS"), True),
        "eps": 1e-6,
        "prune_vision": _parse_bool(env("AOP_PRUNE_VISION"), True),
        "prune_text": _parse_bool(env("AOP_PRUNE_TEXT"), False),
        # Vision-specific values (None when not provided).
        "delta_vision": _parse_float(env("AOP_DELTA_VISION"), None),
        "K_hat_vision": _parse_float(env("AOP_KHAT_VISION"), None),
        "keep_ratio_vision": _parse_float(env("AOP_KEEP_RATIO_VISION"), None),
        "min_keep_vision": _parse_int(env("AOP_MIN_KEEP_VISION"), None),
        # Text-specific values.
        "delta_text": _parse_float(env("AOP_DELTA_TEXT"), None),
        "K_hat_text": _parse_float(env("AOP_KHAT_TEXT"), None),
        "keep_ratio_text": _parse_float(env("AOP_KEEP_RATIO_TEXT"), None),
        "min_keep_text": _parse_int(env("AOP_MIN_KEEP_TEXT"), 32),
        "protect_text_last": _parse_int(env("AOP_PROTECT_TEXT_LAST"), 16),
        "protect_special": _parse_bool(env("AOP_PROTECT_SPECIAL"), True),
        "margin_mid": "USE_MID_MARGIN" if margin_src == "mid" else None,
        "epsilon_hat": None,
        "attn_impl_override": attn_impl if attn_impl == "sdpa" else "",
        "selection": selection,
        "random_seed": _parse_int(env("AOP_RANDOM_SEED"), None),
        "attn_agg": _lowered("AOP_ATTENTION_AGG", "mean"),  # mean|max|sum
    }
    return config
def apply_aop_to_model(model):
    """Inject the AOP configuration into the base model.

    The configuration built by ``get_env_aop_config`` is stored as
    ``model.encoder.aop_prune_config``. When attention-based importance is
    needed, ``AOP_ATTN_IMPL=sdpa`` is supported (per the original note:
    flash_attn2 does not output attention).

    Args:
        model: Object exposing an ``encoder`` attribute; if
            ``encoder.model.config`` exists, its ``_attn_implementation`` may
            be overridden.

    Returns:
        dict: The AOP configuration. When AOP is disabled the model is left
        untouched and the config is still returned.
    """
    cfg = get_env_aop_config()
    if not cfg["enabled"]:
        print("[AOP] disabled")
        return cfg

    model.encoder.aop_prune_config = cfg

    override = cfg.get("attn_impl_override", "")
    if override:
        # Best effort: a failure to switch the attention implementation is
        # reported but must not abort model setup.
        try:
            encoder = model.encoder
            if hasattr(encoder, "model") and hasattr(encoder.model, "config"):
                hf_config = encoder.model.config
                previous = hf_config._attn_implementation
                hf_config._attn_implementation = override
                print(f"[AOP] override attn impl: {previous} -> {override}")
        except Exception as err:
            print(f"[AOP] try override attn impl failed: {err}")

    # Log a compact summary of the active settings.
    summary = {name: cfg[name] for name in ("apply_to", "layer_idx", "mode")}
    optional_fields = (
        ("prune_text", False),
        ("keep_ratio_text", None),
        ("keep_ratio_vision", None),
        ("selection", "aop"),
        ("attn_agg", "mean"),
    )
    for name, fallback in optional_fields:
        summary[name] = cfg.get(name, fallback)
    print("[AOP] config:", json.dumps(summary))
    return cfg
@contextmanager
def aop_side(model, side: str):
    """Temporarily scope AOP to one side and restore the flag on exit.

    Usage: ``with aop_side(model, "qry"): ...`` keeps AOP enabled inside the
    block only if it was already enabled and the config's ``apply_to`` is
    ``"both"`` or equals ``side``. The previous ``enabled`` value is restored
    when the block exits, including on exceptions.

    If ``model.encoder`` has no non-empty dict ``aop_prune_config``, the
    context manager does nothing.

    Args:
        model: Object exposing ``encoder`` (which may carry ``aop_prune_config``).
        side: Side being processed, e.g. ``"qry"`` or ``"cand"``.
    """
    cfg = getattr(model.encoder, "aop_prune_config", None)
    saved_flag = None
    if isinstance(cfg, dict) and cfg:
        saved_flag = cfg.get("enabled", False)
        target = cfg.get("apply_to", "qry")
        side_matches = target in ("both", side)
        cfg["enabled"] = bool(side_matches and saved_flag)
        model.encoder.aop_prune_config = cfg
    try:
        yield
    finally:
        # Put the original flag back even if the body raised.
        if isinstance(cfg, dict) and saved_flag is not None:
            cfg["enabled"] = saved_flag
            model.encoder.aop_prune_config = cfg