File size: 5,743 Bytes
0a937d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# src/utils/aop_utils.py
import os, json
from contextlib import contextmanager

def _parse_bool(v: str, default=False):
    if v is None: return default
    v = v.strip().lower()
    return v in {"1","true","yes","y","t","on"}

def _parse_float(v: str, default=None):
    try: return float(v) if v is not None else default
    except: return default

def _parse_int(v: str, default=None):
    try: return int(v) if v is not None else default
    except: return default

def get_env_aop_config():
    """Build the AOP pruning configuration from ``AOP_*`` environment variables.

    Each setting is read from ``os.environ``; unset variables use the fallback
    shown next to them below. String settings are trimmed and lower-cased.

    Returns:
        dict: Flat configuration dictionary. ``"enabled"`` is ``True`` only
        when ``AOP_ENABLED`` is truthy and ``AOP_LAYER`` parses as an integer.
    """
    env = os.environ.get

    def _lowered(name, fallback):
        # Free-form string settings are compared case-insensitively downstream,
        # so normalise them once here.
        return env(name, fallback).strip().lower()

    layer_idx = _parse_int(env("AOP_LAYER"), None)
    # With no layer index the feature is switched off, even if AOP_ENABLED
    # was set.
    enabled = _parse_bool(env("AOP_ENABLED"), False) and layer_idx is not None

    # AOP_RANDOM is a shortcut that overrides AOP_SELECTION.
    if _parse_bool(env("AOP_RANDOM"), False):
        selection = "random"
    else:
        selection = _lowered("AOP_SELECTION", "aop")  # aop|attention|random

    margin_src = _lowered("AOP_MARGIN", "")
    attn_impl = _lowered("AOP_ATTN_IMPL", "")  # "" | "sdpa"

    return {
        "enabled": enabled,
        "apply_to": _lowered("AOP_APPLY", "qry"),  # qry|cand|both
        "layer_idx": layer_idx,
        "mode": _lowered("AOP_MODE", "delta"),

        # Global settings.
        "delta": _parse_float(env("AOP_DELTA"), 0.10),
        "K_hat": _parse_float(env("AOP_KHAT"), 1.0),
        "keep_ratio": _parse_float(env("AOP_KEEP_RATIO"), 1.0),
        "min_keep": _parse_int(env("AOP_MIN_KEEP"), 64),
        "use_bias": _parse_bool(env("AOP_USE_BIAS"), True),
        "eps": 1e-6,

        # Which modalities get pruned.
        "prune_vision": _parse_bool(env("AOP_PRUNE_VISION"), True),
        "prune_text": _parse_bool(env("AOP_PRUNE_TEXT"), False),

        # Vision-specific settings (None when unset).
        "delta_vision": _parse_float(env("AOP_DELTA_VISION"), None),
        "K_hat_vision": _parse_float(env("AOP_KHAT_VISION"), None),
        "keep_ratio_vision": _parse_float(env("AOP_KEEP_RATIO_VISION"), None),
        "min_keep_vision": _parse_int(env("AOP_MIN_KEEP_VISION"), None),

        # Text-specific settings (None when unset, except min_keep_text).
        "delta_text": _parse_float(env("AOP_DELTA_TEXT"), None),
        "K_hat_text": _parse_float(env("AOP_KHAT_TEXT"), None),
        "keep_ratio_text": _parse_float(env("AOP_KEEP_RATIO_TEXT"), None),
        "min_keep_text": _parse_int(env("AOP_MIN_KEEP_TEXT"), 32),

        "protect_text_last": _parse_int(env("AOP_PROTECT_TEXT_LAST"), 16),
        "protect_special": _parse_bool(env("AOP_PROTECT_SPECIAL"), True),

        "margin_mid": "USE_MID_MARGIN" if margin_src == "mid" else None,
        "epsilon_hat": None,
        # Only "sdpa" is accepted as an override; anything else becomes "".
        "attn_impl_override": attn_impl if attn_impl == "sdpa" else "",

        "selection": selection,
        "random_seed": _parse_int(env("AOP_RANDOM_SEED"), None),
        "attn_agg": _lowered("AOP_ATTENTION_AGG", "mean"),  # mean|max|sum
    }

def apply_aop_to_model(model):
    """
    Inject the AOP config into the base model as ``model.encoder.aop_prune_config``.

    The config comes from ``get_env_aop_config()``. If it is disabled, nothing
    is attached and the config is returned immediately.

    When attention-based importance is needed, ``AOP_ATTN_IMPL=sdpa`` is
    supported (original author's note: flash_attn2 does not output attention).
    The override is best-effort: a failure is printed, not raised.

    Returns:
        dict: The AOP configuration that was read, attached or not.
    """
    cfg = get_env_aop_config()
    if not cfg["enabled"]:
        print("[AOP] disabled")
        return cfg

    model.encoder.aop_prune_config = cfg

    requested_impl = cfg.get("attn_impl_override", "")
    if requested_impl:
        try:
            has_inner_config = hasattr(model.encoder, "model") and hasattr(
                model.encoder.model, "config"
            )
            if has_inner_config:
                old_impl = model.encoder.model.config._attn_implementation
                model.encoder.model.config._attn_implementation = requested_impl
                print(f"[AOP] override attn impl: {old_impl} -> {requested_impl}")
        except Exception as err:
            print(f"[AOP] try override attn impl failed: {err}")

    # Print a compact summary of the settings rather than the whole config.
    summary = {
        "apply_to": cfg["apply_to"],
        "layer_idx": cfg["layer_idx"],
        "mode": cfg["mode"],
        "prune_text": cfg.get("prune_text", False),
        "keep_ratio_text": cfg.get("keep_ratio_text", None),
        "keep_ratio_vision": cfg.get("keep_ratio_vision", None),
        "selection": cfg.get("selection", "aop"),
        "attn_agg": cfg.get("attn_agg", "mean"),
    }
    print("[AOP] config:", json.dumps(summary))
    return cfg

@contextmanager
def aop_side(model, side: str):
    """
    ``with aop_side(model, "qry"):`` enables AOP only for the given side.

    Inside the block, ``model.encoder.aop_prune_config["enabled"]`` is true
    only if it was already truthy and the config's ``apply_to`` is ``"both"``
    or equals ``side``. The previous value is restored on exit, including when
    the body raises. If the encoder has no non-empty dict config, the context
    manager does nothing.
    """
    cfg = getattr(model.encoder, "aop_prune_config", None)
    saved_flag = None
    if isinstance(cfg, dict) and cfg:
        saved_flag = cfg.get("enabled", False)
        target = cfg.get("apply_to", "qry")
        side_matches = target in ("both", side)
        # Only narrow an already-enabled config; never switch AOP on here.
        cfg["enabled"] = bool(saved_flag) if side_matches else False
        model.encoder.aop_prune_config = cfg
    try:
        yield
    finally:
        if saved_flag is not None and isinstance(cfg, dict):
            cfg["enabled"] = saved_flag
            model.encoder.aop_prune_config = cfg