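"""Dump Phase-A early-exit features for multimodal retrieval (MMEB).

For every dataset in the YAML config, this script encodes the candidate pool
at both a mid layer (EE_LAYER) and the last layer, splits the queries into
train/val/test, and writes per-query JSONL rows with 13 exit features and a
binary label y_exit (1 iff the mid-layer and last-layer top-1 candidates
agree). It also saves a per-dataset mean/std feature scaler computed on the
train split.
"""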
import datetime
import json
import math
import os
import random
import time

import numpy as np
import torch
import torch.distributed as dist
import yaml
from datasets import concatenate_datasets
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, HfArgumentParser

from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel
from src.model.processor import get_backbone_name, load_processor
from src.utils import batch_to_device, print_master


def _parse_bool(v: str, default=False):
    """Parse a truthy string (e.g. from an environment variable); fall back to `default`."""
    if v is None:
        return default
    v = v.strip().lower()
    return v in {"1", "true", "yes", "y", "t", "on"}


def _parse_int(v: str, default=None):
    """Parse an integer string; fall back to `default` on None or bad input."""
    try:
        return int(v) if v is not None else default
    except (TypeError, ValueError):
        return default


def _parse_float(v: str, default=None):
    """Parse a float string; fall back to `default` on None or bad input."""
    try:
        return float(v) if v is not None else default
    except (TypeError, ValueError):
        return default


def get_env_aop_config():
    """Collect the AOP pruning configuration from environment variables."""
    enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
    apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower()
    layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
    mode = os.environ.get("AOP_MODE", "ratio").strip().lower()
    prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
    prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)
    keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
    keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
    attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower()
    ee_layer = _parse_int(os.environ.get("EE_LAYER"), None)

    return {
        "enabled": enabled,
        "apply_to": apply_to,
        "layer_idx": layer_idx,
        "mode": mode,
        "prune_vision": prune_vision,
        "prune_text": prune_text,
        "keep_ratio_vision": keep_ratio_v,
        "keep_ratio_text": keep_ratio_t,
        "attn_agg": attn_agg,
        "ee_layer": ee_layer,
    }
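

# Illustrative shell setup for the AOP variables read above (the variable
# names come from this file; the values are hypothetical, not repo defaults):
#
#   export AOP_ENABLED=true AOP_APPLY=qry AOP_LAYER=12 AOP_KEEP_RATIO_VISION=0.5
#
# get_env_aop_config() would then return enabled=True, apply_to="qry",
# layer_idx=12 and keep_ratio_vision=0.5, with the remaining keys at their
# defaults.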


def pad_dataset_to_divisible(dataset, world_size):
    """Pad `dataset` with wrap-around copies of its own rows so its length is divisible by `world_size`."""
    num_samples = len(dataset)
    if num_samples % world_size == 0:
        return dataset, num_samples
    num_to_add = world_size - (num_samples % world_size)
    padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
    padded_dataset = concatenate_datasets([dataset, padding_data])
    return padded_dataset, num_samples + num_to_add
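
# Example: a 10-row dataset on world_size=4 gets rows 0 and 1 appended,
# giving 12 rows (3 per rank).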


@torch.no_grad()
def encode_candidates_both_layers(model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, mid_layer: int):
    """Encode all candidates once, pooling both mid-layer and last-layer hidden states."""
    model.eval()
    all_mid, all_last, all_ids = [], [], []

    # Temporarily disable AOP pruning so candidates are encoded on the full,
    # unpruned sequence. The save/disable happens once before the loop (the
    # original code re-read and re-disabled the config on every batch and
    # would hit a NameError on an empty loader); the config is restored after.
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    if isinstance(aop_cfg, dict) and aop_cfg:
        aop_off = dict(aop_cfg)
        aop_off["enabled"] = False
        setattr(model.encoder, "aop_prune_config", aop_off)

    for inputs, infos in tqdm(loader, desc="[DUMP] Cands[BOTH]", disable=False):
        inputs = batch_to_device(inputs, training_args.device)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            out = model.encoder(
                **inputs,
                return_dict=True,
                output_hidden_states=True,
                stop_at_layer=None,
                compute_lm_head=False,
            )
        hs_list = out.hidden_states
        assert hs_list is not None and len(hs_list) > mid_layer, "hidden_states too short for mid_layer"
        mid_hs, last_hs = hs_list[mid_layer], hs_list[-1]
        am = inputs.get("attention_mask", None)
        if am is not None and hasattr(am, "device") and am.device != mid_hs.device:
            am = am.to(mid_hs.device)
        reps_mid = model._pooling(mid_hs, am).detach().float().cpu()
        reps_last = model._pooling(last_hs, am).detach().float().cpu()
        all_mid.append(reps_mid)
        all_last.append(reps_last)
        all_ids.extend([info["cand_name"] for info in infos])

    # Restore the original AOP config.
    if isinstance(aop_cfg, dict) and aop_cfg:
        setattr(model.encoder, "aop_prune_config", aop_cfg)

    cand_mid = torch.cat(all_mid, dim=0).numpy()
    cand_last = torch.cat(all_last, dim=0).numpy()
    return cand_mid, cand_last, all_ids
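
# Note: each candidate is encoded in a single full forward pass; pooling the
# mid-layer hidden states scores the early-exit path, while pooling the
# last-layer hidden states scores the full-depth path, so the two rankings can
# later be compared per query.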


@torch.no_grad()
def build_phaseA_features_global(
    reps_mid_t: torch.Tensor,
    cand_mid_t: torch.Tensor,
    am_mid: torch.Tensor,
    input_ids: torch.Tensor,
    cfg,
    topk: int = 200,
    temp: float = 0.05,
):
    """Build per-query exit features against the full (global) candidate pool.

    Feature layout (13 dims):
    [s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT]
    """
    # Retrieval-confidence features from the mid-layer similarity scores.
    scores_t = reps_mid_t @ cand_mid_t.T
    k = min(topk, scores_t.size(1))
    vals_t, _ = torch.topk(scores_t, k=k, dim=1)
    s1 = vals_t[:, 0]
    s2 = vals_t[:, 1] if k >= 2 else torch.zeros_like(s1)
    margin = s1 - s2
    p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=1)
    # Normalized entropy of the top-k softmax distribution, in [0, 1].
    # Guard the k == 1 case, where log(k) == 0 would produce 0/0.
    log_k = math.log(k) if k > 1 else 1.0
    H = -(p_t * torch.log(p_t + 1e-12)).sum(dim=1) / log_k
    sum_p2 = (p_t ** 2).sum(dim=1)

    # Token-composition features: split valid tokens into vision / text / special.
    am = am_mid.to(torch.bool)
    iid = input_ids
    image_token_id = getattr(cfg, "image_token_id", None)
    video_token_id = getattr(cfg, "video_token_id", None)
    bos_id = getattr(cfg, "bos_token_id", None)
    eos_id = getattr(cfg, "eos_token_id", None)
    pad_id = getattr(cfg, "pad_token_id", None)
    is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
    is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
    is_vision = (is_image | is_video) & am

    is_special = torch.zeros_like(iid, dtype=torch.bool)
    for tid in [bos_id, eos_id, pad_id]:
        if tid is not None and tid >= 0:
            is_special |= (iid == tid)
    is_text = am & (~is_vision) & (~is_special)

    L_vis = is_vision.sum(dim=1).float()
    L_txt = is_text.sum(dim=1).float()
    L_tot = am.sum(dim=1).float().clamp(min=1.0)
    r_vis = L_vis / L_tot
    r_txt = L_txt / L_tot

    # One-hot modality indicators: image-only, text-only, image+text.
    is_I = ((L_vis > 0) & (L_txt == 0)).float()
    is_T = ((L_txt > 0) & (L_vis == 0)).float()
    is_IT = ((L_txt > 0) & (L_vis > 0)).float()

    feats = torch.stack([s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT], dim=1)
    return feats
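
# Worked example of the normalized-entropy feature H: with k=2 and temp=0.05,
# top scores (0.9, 0.5) give softmax probs of roughly (0.99966, 0.00034), so
# H = -(sum p*log p) / log 2 is about 0.004 (a confident, peaked ranking);
# two equal scores give exactly H = 1.0 (maximally uncertain).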


@torch.no_grad()
def build_phaseA_features_local(
    reps_mid_t: torch.Tensor,
    cand_mid_t: torch.Tensor,
    am_mid: torch.Tensor,
    input_ids: torch.Tensor,
    cfg,
    per_sample_rows: list,
    topk: int = 200,
    temp: float = 0.05,
):
    """Like build_phaseA_features_global, but each query is scored only against
    its own local candidate rows (per_sample_rows[b] indexes into cand_mid_t)."""
    device = reps_mid_t.device
    B = reps_mid_t.size(0)
    s1_list, s2_list, H_list, sum_p2_list = [], [], [], []
    for b in range(B):
        rows = per_sample_rows[b]
        if len(rows) == 0:
            # No valid candidates: zero scores and maximal entropy.
            s1_list.append(torch.tensor(0.0, device=device))
            s2_list.append(torch.tensor(0.0, device=device))
            H_list.append(torch.tensor(1.0, device=device))
            sum_p2_list.append(torch.tensor(0.0, device=device))
            continue
        cmat = cand_mid_t[rows]
        sv = (reps_mid_t[b:b + 1] @ cmat.T)[0]
        k = min(topk, sv.size(0))
        vals, _ = torch.topk(sv, k=k, dim=0)
        s1_list.append(vals[0])
        s2_list.append(vals[1] if k >= 2 else torch.tensor(0.0, device=device, dtype=vals.dtype))
        p = torch.softmax(vals / max(temp, 1e-6), dim=0)
        # Same k == 1 guard as in the global variant (log(1) == 0).
        log_k = math.log(k) if k > 1 else 1.0
        H_list.append(-(p * torch.log(p + 1e-12)).sum() / log_k)
        sum_p2_list.append((p ** 2).sum())
    s1 = torch.stack(s1_list)
    s2 = torch.stack(s2_list)
    H = torch.stack(H_list)
    sum_p2 = torch.stack(sum_p2_list)
    margin = s1 - s2

    # Token-composition features, identical to the global variant.
    am = am_mid.to(torch.bool)
    iid = input_ids
    image_token_id = getattr(cfg, "image_token_id", None)
    video_token_id = getattr(cfg, "video_token_id", None)
    bos_id = getattr(cfg, "bos_token_id", None)
    eos_id = getattr(cfg, "eos_token_id", None)
    pad_id = getattr(cfg, "pad_token_id", None)
    is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
    is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
    is_vision = (is_image | is_video) & am

    is_special = torch.zeros_like(iid, dtype=torch.bool)
    for tid in [bos_id, eos_id, pad_id]:
        if tid is not None and tid >= 0:
            is_special |= (iid == tid)
    is_text = am & (~is_vision) & (~is_special)

    L_vis = is_vision.sum(dim=1).float()
    L_txt = is_text.sum(dim=1).float()
    L_tot = am.sum(dim=1).float().clamp(min=1.0)
    r_vis = L_vis / L_tot
    r_txt = L_txt / L_tot

    is_I = ((L_vis > 0) & (L_txt == 0)).float()
    is_T = ((L_txt > 0) & (L_vis == 0)).float()
    is_IT = ((L_txt > 0) & (L_vis > 0)).float()

    feats = torch.stack([s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT], dim=1)
    return feats


def compute_label_top1_equal_global(scores_mid: np.ndarray, scores_last: np.ndarray) -> np.ndarray:
    """Return 1 per query iff the mid-layer and last-layer top-1 candidates agree (global pool)."""
    top1_mid = scores_mid.argmax(axis=1)
    top1_last = scores_last.argmax(axis=1)
    return (top1_mid == top1_last).astype(np.int32)


def compute_label_top1_equal_local(scores_mid_list, scores_last_list):
    """Per-query top-1 agreement label over each query's local candidate list."""
    y = []
    for sv_mid, sv_last in zip(scores_mid_list, scores_last_list):
        if sv_mid.size == 0 or sv_last.size == 0:
            # No valid candidates: conservatively label as "do not exit".
            y.append(0)
        else:
            y.append(int(int(sv_mid.argmax()) == int(sv_last.argmax())))
    return np.array(y, dtype=np.int32)
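
# Example: mid-layer scores [0.2, 0.8, 0.1] and last-layer scores
# [0.3, 0.9, 0.2] both argmax to index 1, so y_exit = 1 (exiting early would
# not change the top-1 retrieval result).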


def main():
    # Initialize the process group when launched with torchrun/distributed.
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    os.makedirs(data_args.encode_output_path, exist_ok=True)

    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, "model_backbone", model_backbone)
        setattr(training_args, "model_backbone", model_backbone)

    # Rank 0 loads first (populating any shared cache); the other ranks wait at
    # the barrier, then load with a small random stagger so they do not all hit
    # the filesystem at once.
    if local_rank == 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    if dist.is_initialized():
        dist.barrier()
    if local_rank != 0:
        processor = load_processor(model_args, data_args)
        time.sleep(random.randint(2 * local_rank, 3 * local_rank))
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)

    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)

    # Early-exit layer and feature knobs, configurable via the environment.
    ee_layer = int(os.environ.get("EE_LAYER", os.environ.get("AOP_LAYER", "12")))
    feat_topk = int(os.environ.get("EE_FEAT_TOPK", "200"))
    force_no_aop = _parse_bool(os.environ.get("DUMP_EXIT_NO_AOP", "1"), True)

    # Query-level split ratios; the remaining 80% of queries become the test split.
    TRAIN_RATIO = 0.1
    VAL_RATIO = 0.1

    with open(data_args.dataset_config, "r", encoding="utf-8") as yf:
        dataset_configs = yaml.safe_load(yf)

    for dataset_name, task_cfg in dataset_configs.items():
        if dist.is_initialized():
            dist.barrier()
        print_master(f"\n[DUMP] Processing {dataset_name} ...")

        # Resolve relative data paths against the shared base directory.
        if data_args.data_basedir:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if task_cfg.get(key):
                    task_cfg[key] = os.path.join(data_args.data_basedir, task_cfg[key])

        full_qry, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_cfg)
        full_cand = generate_cand_dataset(full_qry, corpus)

        # Deterministic query-level split. train_test_split is applied twice:
        # VAL_RATIO is rescaled to the post-train remainder so the absolute
        # val fraction stays at 10%.
        total_len = len(full_qry)
        all_indices = np.arange(total_len)
        train_idxs, temp_idxs = train_test_split(
            all_indices, train_size=TRAIN_RATIO, random_state=42, shuffle=True
        )
        val_relative_ratio = VAL_RATIO / (1.0 - TRAIN_RATIO)
        val_idxs, test_idxs = train_test_split(
            temp_idxs, train_size=val_relative_ratio, random_state=42, shuffle=True
        )

        print_master(f"[DUMP] Split sizes -> Train: {len(train_idxs)}, Val: {len(val_idxs)}, Test: {len(test_idxs)}")

        splits = {
            "train": {"ds": full_qry.select(train_idxs), "indices": train_idxs},
            "val": {"ds": full_qry.select(val_idxs), "indices": val_idxs},
            "test": {"ds": full_qry.select(test_idxs), "indices": test_idxs},
        }

        # Encode the full candidate pool once per dataset, at both exit layers.
        cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
        cand_loader = DataLoader(
            full_cand,
            batch_size=training_args.per_device_eval_batch_size,
            collate_fn=cand_collator,
            num_workers=training_args.dataloader_num_workers,
        )
        cand_mid_np, cand_last_np, cand_ids = encode_candidates_both_layers(model, cand_loader, training_args, mid_layer=ee_layer)
        cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
        device = training_args.device
        cand_mid_t = torch.from_numpy(cand_mid_np).to(device=device, dtype=torch.bfloat16)
        cand_last_t = None  # materialized lazily, only once labels are needed

        # Running first/second moments for the train-split feature scaler.
        sum_feat, sum2_feat, n_feat = None, None, 0
        scaler_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_phaseA_scaler.json")

        for split_name, split_info in splits.items():
            qry_dataset = split_info["ds"]
            global_indices = split_info["indices"]
            if len(qry_dataset) == 0:
                continue

            # Shard the split contiguously across ranks. Note that remainder
            # rows beyond world_size * per_rank are dropped.
            if dist.is_initialized():
                world_size = dist.get_world_size()
                per_rank = len(qry_dataset) // world_size
                start_idx = local_rank * per_rank
                end_idx = start_idx + per_rank
                if start_idx >= len(qry_dataset):
                    local_dataset = qry_dataset.select([])
                    local_indices = []
                else:
                    local_dataset = qry_dataset.select(range(start_idx, end_idx))
                    local_indices = global_indices[start_idx:end_idx]
            else:
                local_dataset = qry_dataset
                local_indices = global_indices

            qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
            qry_loader = DataLoader(
                local_dataset,
                batch_size=training_args.per_device_eval_batch_size,
                collate_fn=qry_collator,
                num_workers=training_args.dataloader_num_workers,
                shuffle=False,
            )

            feat_out_path_rank = os.path.join(data_args.encode_output_path, f"{dataset_name}_{split_name}_features.jsonl.rank{local_rank}")
            print_master(f" -> Dump {split_name} features to {feat_out_path_rank} ...")

            # `cursor` walks through local_indices so each batch can be mapped
            # back to its global query ids (the loader is not shuffled).
            cursor = 0
            with open(feat_out_path_rank, "w", encoding="utf-8") as fout:
                for inputs, infos in tqdm(qry_loader, desc=f"[{split_name.upper()}]", disable=(local_rank != 0)):
                    inputs = batch_to_device(inputs, device)
                    B = inputs["input_ids"].size(0)
                    batch_global_ids = local_indices[cursor:cursor + B]
                    cursor += B

                    # Pass 1: run layers [0, ee_layer) with AOP pruning
                    # disabled, returning the intermediate state so the
                    # forward pass can be resumed later.
                    aop_cfg_cur = getattr(model.encoder, "aop_prune_config", None)
                    orig_aop = None
                    if force_no_aop and isinstance(aop_cfg_cur, dict):
                        orig_aop = dict(aop_cfg_cur)
                        aop_off = dict(aop_cfg_cur)
                        aop_off["enabled"] = False
                        setattr(model.encoder, "aop_prune_config", aop_off)

                    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                        out_mid = model.encoder(
                            **inputs,
                            return_dict=True,
                            output_hidden_states=False,
                            stop_at_layer=int(ee_layer),
                            compute_lm_head=False,
                            return_intermediate_state=True,
                        )
                    if orig_aop is not None:
                        setattr(model.encoder, "aop_prune_config", orig_aop)

                    hs_mid = getattr(out_mid, "last_hidden_state", None)
                    if hs_mid is None:
                        hs_mid = out_mid.hidden_states[-1]
                    am_mid = getattr(out_mid, "attention_mask", None)
                    if am_mid is None:
                        am_mid = inputs.get("attention_mask")
                    if hasattr(am_mid, "device") and am_mid.device != hs_mid.device:
                        am_mid = am_mid.to(hs_mid.device)
                    reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(device=device, dtype=torch.bfloat16)

                    # Phase-A features at the exit layer, scored against the
                    # global pool or each query's local candidate list.
                    rank_global = task_cfg.get("eval_type", "global") == "global"
                    if rank_global:
                        feats_t = build_phaseA_features_global(reps_mid_t, cand_mid_t, am_mid, inputs["input_ids"], model.encoder.config, topk=feat_topk)
                    else:
                        rows_list = []
                        for b_idx in range(B):
                            cand_local = infos[b_idx]["cand_names"]
                            rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                            rows = [r for r in rows if r >= 0]
                            rows_list.append(rows)
                        feats_t = build_phaseA_features_local(reps_mid_t, cand_mid_t, am_mid, inputs["input_ids"], model.encoder.config, rows_list, topk=feat_topk)
                    feats_np = feats_t.detach().float().cpu().numpy()

                    # Pass 2: resume from the saved intermediate state and run
                    # the remaining layers for the full-depth representation.
                    interm = getattr(out_mid, "intermediate_state", None)
                    assert interm is not None, "encoder did not return intermediate_state"
                    resume_state = {
                        "hidden_states": interm["hidden_states"].detach(),
                        "attention_mask": interm["attention_mask"].detach(),
                        "position_ids": interm["position_ids"].detach(),
                        "vision_mask": interm.get("vision_mask"),
                        "text_mask": interm.get("text_mask"),
                        "next_layer_idx": int(interm["next_layer_idx"]),
                    }
                    aop_cfg_cur = getattr(model.encoder, "aop_prune_config", None)
                    orig_aop2 = None
                    if force_no_aop and isinstance(aop_cfg_cur, dict):
                        orig_aop2 = dict(aop_cfg_cur)
                        aop_off2 = dict(aop_cfg_cur)
                        aop_off2["enabled"] = False
                        setattr(model.encoder, "aop_prune_config", aop_off2)

                    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                        out_last = model.encoder(
                            return_dict=True,
                            output_hidden_states=False,
                            stop_at_layer=None,
                            resume_state=resume_state,
                            compute_lm_head=False,
                        )
                    if orig_aop2 is not None:
                        setattr(model.encoder, "aop_prune_config", orig_aop2)

                    hs_last = getattr(out_last, "last_hidden_state", None)
                    if hs_last is None:
                        hs_last = out_last.hidden_states[-1]
                    am_last = getattr(out_last, "attention_mask", None)
                    if am_last is None:
                        am_last = resume_state["attention_mask"]
                    if hasattr(am_last, "device") and am_last.device != hs_last.device:
                        am_last = am_last.to(hs_last.device)
                    reps_last_t = model._pooling(hs_last, am_last).detach().to(device=device, dtype=torch.bfloat16)

                    # Exit label: 1 iff the mid-layer and last-layer rankings
                    # agree on the top-1 candidate.
                    if rank_global:
                        if cand_last_t is None:
                            cand_last_t = torch.from_numpy(cand_last_np).to(device=device, dtype=torch.bfloat16)
                        sim_mid = (reps_mid_t @ cand_mid_t.T).detach().float().cpu().numpy()
                        sim_last = (reps_last_t @ cand_last_t.T).detach().float().cpu().numpy()
                        y = compute_label_top1_equal_global(sim_mid, sim_last)
                    else:
                        y_list = []
                        for b_idx in range(B):
                            cand_local = infos[b_idx]["cand_names"]
                            rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                            rows = [r for r in rows if r >= 0]
                            if not rows:
                                y_list.append(0)
                                continue
                            if cand_last_t is None:
                                cand_last_t = torch.from_numpy(cand_last_np).to(device=device, dtype=torch.bfloat16)
                            c_mid = cand_mid_t[rows]
                            c_last = cand_last_t[rows]
                            sv_mid = (reps_mid_t[b_idx:b_idx + 1] @ c_mid.T)[0].detach().float().cpu().numpy()
                            sv_last = (reps_last_t[b_idx:b_idx + 1] @ c_last.T)[0].detach().float().cpu().numpy()
                            y_list.append(int(int(sv_mid.argmax()) == int(sv_last.argmax())))
                        y = np.array(y_list, dtype=np.int32)

                    # Accumulate scaler statistics on the train split only.
                    if split_name == "train":
                        if sum_feat is None:
                            sum_feat = feats_np.sum(axis=0)
                            sum2_feat = (feats_np ** 2).sum(axis=0)
                        else:
                            sum_feat += feats_np.sum(axis=0)
                            sum2_feat += (feats_np ** 2).sum(axis=0)
                        n_feat += feats_np.shape[0]

                    # Recover the query modality from the token-count features
                    # (index 5 = L_txt, index 6 = L_vis in the feature layout).
                    L_txt = feats_np[:, 5]
                    L_vis = feats_np[:, 6]
                    types = np.where((L_vis > 0) & (L_txt == 0), "I", np.where((L_txt > 0) & (L_vis == 0), "T", "IT"))

                    for b_idx in range(B):
                        row = {
                            "dataset": dataset_name,
                            "split": split_name,
                            "qid": int(batch_global_ids[b_idx]),
                            "type": str(types[b_idx]),
                            "feats": feats_np[b_idx].tolist(),
                            "y_exit": int(y[b_idx]),
                        }
                        fout.write(json.dumps(row, ensure_ascii=False) + "\n")

        # Reduce the train-split statistics across ranks, then have rank 0
        # write the per-dataset feature scaler (mean/std over the 13 dims).
        if dist.is_initialized():
            dist.barrier()
            stats_vec = torch.tensor(
                np.concatenate([sum_feat, sum2_feat, [n_feat]]) if n_feat > 0 else np.zeros(13 * 2 + 1),
                device=device, dtype=torch.float64,
            )
            dist.all_reduce(stats_vec, op=dist.ReduceOp.SUM)
            sum_feat_all = stats_vec[:13].cpu().numpy()
            sum2_feat_all = stats_vec[13:26].cpu().numpy()
            n_feat_all = stats_vec[26].item()
        else:
            sum_feat_all = sum_feat
            sum2_feat_all = sum2_feat
            n_feat_all = n_feat

        if local_rank == 0 and n_feat_all > 0:
            mean = (sum_feat_all / n_feat_all).tolist()
            # Population variance via E[x^2] - E[x]^2, floored at zero; std is
            # floored at 1e-6 to keep later normalization divide-safe.
            var = sum2_feat_all / n_feat_all - (sum_feat_all / n_feat_all) ** 2
            std = [float(max(1e-6, math.sqrt(max(0.0, v)))) for v in var.tolist()]
            with open(scaler_path, "w", encoding="utf-8") as f:
                json.dump({"mean": mean, "std": std, "in_dim": len(mean), "n_samples": n_feat_all, "dataset": dataset_name}, f, indent=2)
            print_master(f"[DUMP] {dataset_name} Scaler saved -> {scaler_path}")
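
            # The saved scaler JSON has the shape written above, e.g. (values
            # hypothetical):
            #   {"mean": [...13 floats...], "std": [...13 floats...],
            #    "in_dim": 13, "n_samples": 12345, "dataset": "<dataset_name>"}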

        if dist.is_initialized():
            dist.barrier()


if __name__ == "__main__":
    main()