| | import os |
| | import json |
| | import math |
| | import time |
| | import random |
| | import datetime |
| | import numpy as np |
| | import torch |
| | import torch.distributed as dist |
| | from tqdm import tqdm |
| | from torch.utils.data import DataLoader |
| | from transformers import HfArgumentParser, AutoConfig |
| | import yaml |
| |
|
| | from src.arguments import ModelArguments, DataArguments, TrainingArguments |
| | from src.data.collator.eval_collator import MultimodalEvalDataCollator |
| | from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset |
| | from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel |
| | from src.model.processor import get_backbone_name, load_processor |
| | from src.utils import batch_to_device, print_master |
| |
|
| |
|
| | |
| | def _parse_bool(v: str, default=False): |
| | if v is None: |
| | return default |
| | v = v.strip().lower() |
| | return v in {"1", "true", "yes", "y", "t", "on"} |
| |
|
| |
|
| | def _parse_int(v: str, default=None): |
| | try: |
| | return int(v) if v is not None else default |
| | except Exception: |
| | return default |
| |
|
| |
|
| | def _parse_float(v: str, default=None): |
| | try: |
| | return float(v) if v is not None else default |
| | except Exception: |
| | return default |
| |
|
| |
|
def get_env_aop_config():
    """Collect the AOP_* and EE_LAYER environment settings into a dict.

    String options are stripped and lower-cased. Booleans, ints and floats go
    through the lenient ``_parse_*`` helpers, so malformed values fall back
    to the listed defaults instead of raising.

    Returns:
        dict with keys ``enabled``, ``apply_to``, ``layer_idx``, ``mode``,
        ``prune_vision``, ``prune_text``, ``keep_ratio_vision``,
        ``keep_ratio_text``, ``attn_agg`` and ``ee_layer``.
    """
    env = os.environ

    def _lowered(name, fallback):
        # Free-form string option, normalised for case-insensitive comparison.
        return env.get(name, fallback).strip().lower()

    return {
        "enabled": _parse_bool(env.get("AOP_ENABLED"), False),
        "apply_to": _lowered("AOP_APPLY", "qry"),
        "layer_idx": _parse_int(env.get("AOP_LAYER"), None),
        "mode": _lowered("AOP_MODE", "ratio"),
        "prune_vision": _parse_bool(env.get("AOP_PRUNE_VISION"), True),
        "prune_text": _parse_bool(env.get("AOP_PRUNE_TEXT"), False),
        "keep_ratio_vision": _parse_float(env.get("AOP_KEEP_RATIO_VISION"), None),
        "keep_ratio_text": _parse_float(env.get("AOP_KEEP_RATIO_TEXT"), None),
        "attn_agg": _lowered("AOP_ATTENTION_AGG", "mean"),
        "ee_layer": _parse_int(env.get("EE_LAYER"), None),
    }
| |
|
| |
|
def pad_dataset_to_divisible(dataset, world_size):
    """Pad ``dataset`` so its length is a multiple of ``world_size``.

    Padding rows are taken from the start of the dataset, wrapping around
    with a modulo. The dataset is returned unchanged when it already divides
    evenly. The non-divisible path relies on ``dataset.select`` and the
    module-level ``concatenate_datasets``.

    Returns:
        ``(dataset_or_padded_dataset, new_length)``.
    """
    total = len(dataset)
    remainder = total % world_size
    if not remainder:
        return dataset, total
    extra = world_size - remainder
    filler = dataset.select([idx % total for idx in range(extra)])
    return concatenate_datasets([dataset, filler]), total + extra
| |
|
| |
|
| | |
@torch.no_grad()
def encode_candidates_both_layers(model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, mid_layer: int):
    """Encode every candidate once and pool both an intermediate and the last hidden state.

    When ``model.encoder.aop_prune_config`` is a non-empty dict, AOP pruning is
    disabled for the whole pass and the caller's config is restored afterwards,
    even if an exception is raised.

    Args:
        model: wrapper exposing ``encoder`` and ``_pooling``.
        loader: yields ``(inputs, infos)`` batches; each info dict carries ``"cand_name"``.
        training_args: supplies ``device`` for the batches.
        mid_layer: index into the encoder's ``hidden_states`` used for the mid representation.

    Returns:
        ``(cand_mid, cand_last, cand_ids)``: float32 numpy arrays with one
        pooled row per candidate in loader order, plus the matching candidate names.

    Raises:
        ValueError: if the encoder returns too few hidden states for ``mid_layer``.
    """
    model.eval()
    encoder = model.encoder

    # Disable AOP once for the pass and remember the original so it can be put back.
    # Leaving it disabled would silently change every later forward pass.
    saved_cfg = getattr(encoder, "aop_prune_config", None)
    restore_cfg = isinstance(saved_cfg, dict) and bool(saved_cfg)
    if restore_cfg:
        aop_off = dict(saved_cfg)
        aop_off["enabled"] = False
        setattr(encoder, "aop_prune_config", aop_off)

    all_mid, all_last, all_ids = [], [], []
    try:
        for inputs, infos in tqdm(loader, desc="[DUMP] Cands[BOTH]", disable=False):
            inputs = batch_to_device(inputs, training_args.device)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                out = encoder(
                    **inputs,
                    return_dict=True,
                    output_hidden_states=True,
                    stop_at_layer=None,
                    compute_lm_head=False,
                )
            hs_list = out.hidden_states
            if hs_list is None or len(hs_list) <= mid_layer:
                raise ValueError("hidden_states too short for mid_layer")
            mid_hs, last_hs = hs_list[mid_layer], hs_list[-1]

            # The pooling helper needs the mask on the same device as the hidden states.
            am = inputs.get("attention_mask", None)
            if am is not None and hasattr(am, "device") and am.device != mid_hs.device:
                am = am.to(mid_hs.device)

            all_mid.append(model._pooling(mid_hs, am).detach().float().cpu())
            all_last.append(model._pooling(last_hs, am).detach().float().cpu())
            all_ids.extend(info["cand_name"] for info in infos)
    finally:
        if restore_cfg:
            setattr(encoder, "aop_prune_config", saved_cfg)

    cand_mid = torch.cat(all_mid, dim=0).numpy()
    cand_last = torch.cat(all_last, dim=0).numpy()
    return cand_mid, cand_last, all_ids
| |
|
| |
|
| | |
@torch.no_grad()
def build_phaseA_features_global(
    reps_mid_t: torch.Tensor,
    cand_mid_t: torch.Tensor,
    am_mid: torch.Tensor,
    input_ids: torch.Tensor,
    cfg,
    topk: int = 200,
    temp: float = 0.05,
):
    """Build per-query Phase-A features by ranking each query against ALL candidates.

    Args:
        reps_mid_t: (B, D) query representations from the intermediate layer.
        cand_mid_t: (N, D) candidate representations from the same layer.
        am_mid: (B, L) attention mask aligned with ``input_ids``.
        input_ids: (B, L) token ids used to count vision/text tokens.
        cfg: object optionally exposing ``image_token_id``, ``video_token_id``,
            ``bos_token_id``, ``eos_token_id`` and ``pad_token_id``. Each special
            id may be an int or a list/tuple of ints; negative or missing ids are ignored.
        topk: number of top scores used for the confidence statistics.
        temp: softmax temperature over the top-k scores (floored at 1e-6).

    Returns:
        (B, 13) tensor with columns
        [s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT].
        H is the softmax entropy divided by log(k), which is 0 when k == 1.
    """
    # Retrieval-confidence features from the top-k similarity scores.
    scores_t = reps_mid_t @ cand_mid_t.T
    k = min(topk, scores_t.size(1))
    vals_t, _ = torch.topk(scores_t, k=k, dim=1)
    s1 = vals_t[:, 0]
    s2 = vals_t[:, 1] if k >= 2 else torch.zeros_like(s1)
    margin = s1 - s2
    p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=1)
    # Normalise entropy by log(k). For k == 1 the distribution is a point mass
    # (entropy 0) and log(1) == 0 would give 0/0 = NaN, so divide by 1 instead.
    ent_norm = math.log(k) if k > 1 else 1.0
    H = -(p_t * torch.log(p_t + 1e-12)).sum(dim=1) / ent_norm
    sum_p2 = (p_t ** 2).sum(dim=1)

    # Token-composition features.
    am = am_mid.to(torch.bool)
    iid = input_ids
    image_token_id = getattr(cfg, "image_token_id", None)
    video_token_id = getattr(cfg, "video_token_id", None)
    no_tokens = torch.zeros_like(iid, dtype=torch.bool)
    is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else no_tokens
    is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else no_tokens
    is_vision = (is_image | is_video) & am

    # Each special id may be a single int or a list/tuple of ints.
    special_ids = []
    for name in ("bos_token_id", "eos_token_id", "pad_token_id"):
        tid = getattr(cfg, name, None)
        special_ids.extend(tid if isinstance(tid, (list, tuple)) else [tid])
    is_special = torch.zeros_like(iid, dtype=torch.bool)
    for tid in special_ids:
        if tid is not None and tid >= 0:
            is_special |= iid == tid
    is_text = am & (~is_vision) & (~is_special)

    L_vis = is_vision.sum(dim=1).float()
    L_txt = is_text.sum(dim=1).float()
    L_tot = am.sum(dim=1).float().clamp(min=1.0)
    r_vis = L_vis / L_tot
    r_txt = L_txt / L_tot

    # One-hot modality type: image-only, text-only, or image+text.
    is_I = ((L_vis > 0) & (L_txt == 0)).float()
    is_T = ((L_txt > 0) & (L_vis == 0)).float()
    is_IT = ((L_txt > 0) & (L_vis > 0)).float()

    return torch.stack([s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT], dim=1)
| |
|
| |
|
@torch.no_grad()
def build_phaseA_features_local(
    reps_mid_t: torch.Tensor,
    cand_mid_t: torch.Tensor,
    am_mid: torch.Tensor,
    input_ids: torch.Tensor,
    cfg,
    per_sample_rows: list,
    topk: int = 200,
    temp: float = 0.05,
):
    """Build per-query Phase-A features, ranking each query only against its own candidates.

    Args:
        reps_mid_t: (B, D) query representations from the intermediate layer.
        cand_mid_t: (N, D) candidate representations from the same layer.
        am_mid: (B, L) attention mask aligned with ``input_ids``.
        input_ids: (B, L) token ids used to count vision/text tokens.
        cfg: object optionally exposing ``image_token_id``, ``video_token_id``,
            ``bos_token_id``, ``eos_token_id`` and ``pad_token_id``. Each special
            id may be an int or a list/tuple of ints; negative or missing ids are ignored.
        per_sample_rows: for each query b, the list of row indices into ``cand_mid_t``.
            An empty list yields s1 = s2 = 0, H = 1 and sum_p2 = 0.
        topk: number of top scores used for the confidence statistics.
        temp: softmax temperature over the top-k scores (floored at 1e-6).

    Returns:
        (B, 13) tensor with columns
        [s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT].
        H is the softmax entropy divided by log(k), which is 0 when k == 1.
    """
    device = reps_mid_t.device
    B = reps_mid_t.size(0)
    score_dtype = reps_mid_t.dtype

    def _const(x):
        # Placeholder scalar matching the score dtype so the stacks below are uniform.
        return torch.tensor(x, device=device, dtype=score_dtype)

    s1_list, s2_list, H_list, sum_p2_list = [], [], [], []
    for b in range(B):
        rows = per_sample_rows[b]
        if len(rows) == 0:
            s1_list.append(_const(0.0))
            s2_list.append(_const(0.0))
            H_list.append(_const(1.0))
            sum_p2_list.append(_const(0.0))
            continue
        cmat = cand_mid_t[rows]
        sv = (reps_mid_t[b:b + 1] @ cmat.T)[0]
        k = min(topk, sv.size(0))
        vals, _ = torch.topk(sv, k=k, dim=0)
        s1_list.append(vals[0])
        s2_list.append(vals[1] if k >= 2 else torch.tensor(0.0, device=device, dtype=vals.dtype))
        p = torch.softmax(vals / max(temp, 1e-6), dim=0)
        # log(1) == 0 would make a single-candidate entropy 0/0 = NaN; divide by 1 instead.
        ent_norm = math.log(k) if k > 1 else 1.0
        H_list.append(-(p * torch.log(p + 1e-12)).sum() / ent_norm)
        sum_p2_list.append((p ** 2).sum())
    s1 = torch.stack(s1_list)
    s2 = torch.stack(s2_list)
    H = torch.stack(H_list)
    sum_p2 = torch.stack(sum_p2_list)
    margin = s1 - s2

    # Token-composition features.
    am = am_mid.to(torch.bool)
    iid = input_ids
    image_token_id = getattr(cfg, "image_token_id", None)
    video_token_id = getattr(cfg, "video_token_id", None)
    no_tokens = torch.zeros_like(iid, dtype=torch.bool)
    is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else no_tokens
    is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else no_tokens
    is_vision = (is_image | is_video) & am

    # Each special id may be a single int or a list/tuple of ints.
    special_ids = []
    for name in ("bos_token_id", "eos_token_id", "pad_token_id"):
        tid = getattr(cfg, name, None)
        special_ids.extend(tid if isinstance(tid, (list, tuple)) else [tid])
    is_special = torch.zeros_like(iid, dtype=torch.bool)
    for tid in special_ids:
        if tid is not None and tid >= 0:
            is_special |= iid == tid
    is_text = am & (~is_vision) & (~is_special)

    L_vis = is_vision.sum(dim=1).float()
    L_txt = is_text.sum(dim=1).float()
    L_tot = am.sum(dim=1).float().clamp(min=1.0)
    r_vis = L_vis / L_tot
    r_txt = L_txt / L_tot

    # One-hot modality type: image-only, text-only, or image+text.
    is_I = ((L_vis > 0) & (L_txt == 0)).float()
    is_T = ((L_txt > 0) & (L_vis == 0)).float()
    is_IT = ((L_txt > 0) & (L_vis > 0)).float()

    return torch.stack([s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT], dim=1)
| |
|
| |
|
| | |
def compute_label_top1_equal_global(scores_mid: np.ndarray, scores_last: np.ndarray) -> np.ndarray:
    """Flag rows where two score matrices pick the same top-1 column.

    Returns an int32 array with one entry per row: 1 when
    ``argmax(scores_mid[i]) == argmax(scores_last[i])``, else 0.
    """
    agree = np.argmax(scores_mid, axis=1) == np.argmax(scores_last, axis=1)
    return agree.astype(np.int32)
| |
|
| |
|
def compute_label_top1_equal_local(scores_mid_list, scores_last_list):
    """Per-query top-1 agreement over variable-length local score vectors.

    Pairs are walked in lockstep (stopping at the shorter list). A pair
    yields 1 when both vectors are non-empty and share the same arg-max
    index, otherwise 0. Returns an int32 numpy array.
    """
    def _agrees(mid_scores, last_scores):
        if mid_scores.size == 0 or last_scores.size == 0:
            return 0
        return int(int(mid_scores.argmax()) == int(last_scores.argmax()))

    flags = [_agrees(m, last) for m, last in zip(scores_mid_list, scores_last_list)]
    return np.array(flags, dtype=np.int32)
| |
|
| |
|
| | |
def _dump_rank():
    """Return the global rank of this process, or 0 when not running distributed."""
    return dist.get_rank() if dist.is_initialized() else 0


def _dump_disable_aop(encoder):
    """Set ``enabled=False`` in ``encoder.aop_prune_config`` if it is a dict.

    Returns a copy of the previous config so the caller can restore it, or
    None when the encoder has no dict config (nothing was changed).
    """
    cfg = getattr(encoder, "aop_prune_config", None)
    if not isinstance(cfg, dict):
        return None
    aop_off = dict(cfg)
    aop_off["enabled"] = False
    setattr(encoder, "aop_prune_config", aop_off)
    return dict(cfg)


def _dump_pool(model, out, fallback_mask, device):
    """Pool an encoder output's final hidden state into bf16 representations on ``device``.

    The output's own ``attention_mask`` is used when present, otherwise
    ``fallback_mask``. Returns ``(reps, mask_used)``.
    """
    hs = getattr(out, "last_hidden_state", None)
    if hs is None:
        hs = out.hidden_states[-1]
    am = getattr(out, "attention_mask", None)
    if am is None:
        am = fallback_mask
    if hasattr(am, "device") and am.device != hs.device:
        am = am.to(hs.device)
    reps = model._pooling(hs, am).detach().to(device=device, dtype=torch.bfloat16)
    return reps, am


def _dump_forward_mid(model, inputs, device, ee_layer, force_no_aop):
    """Run the query encoder up to ``ee_layer`` and pool the result.

    Returns ``(out_mid, reps_mid_t, am_mid)``. ``out_mid`` carries the
    ``intermediate_state`` used to resume the forward pass.
    """
    encoder = model.encoder
    saved = _dump_disable_aop(encoder) if force_no_aop else None
    try:
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            out_mid = encoder(
                **inputs,
                return_dict=True,
                output_hidden_states=False,
                stop_at_layer=int(ee_layer),
                compute_lm_head=False,
                return_intermediate_state=True,
            )
    finally:
        if saved is not None:
            setattr(encoder, "aop_prune_config", saved)
    reps_mid_t, am_mid = _dump_pool(model, out_mid, inputs.get("attention_mask", None), device)
    return out_mid, reps_mid_t, am_mid


def _dump_forward_last(model, out_mid, device, force_no_aop):
    """Resume the encoder from ``out_mid.intermediate_state`` to the last layer and pool it.

    Raises:
        RuntimeError: if ``out_mid`` has no ``intermediate_state``.
    """
    interm = getattr(out_mid, "intermediate_state", None)
    if interm is None:
        raise RuntimeError("intermediate_state missing; ensure return_intermediate_state=True")
    am = interm["attention_mask"].detach()
    resume_state = {
        "hidden_states": interm["hidden_states"].detach(),
        "attention_mask": am,
        "position_ids": interm["position_ids"].detach(),
        "vision_mask": interm.get("vision_mask", None),
        "text_mask": interm.get("text_mask", None),
        "next_layer_idx": int(interm["next_layer_idx"]),
    }
    encoder = model.encoder
    saved = _dump_disable_aop(encoder) if force_no_aop else None
    try:
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            out_last = encoder(
                return_dict=True,
                output_hidden_states=False,
                stop_at_layer=None,
                resume_state=resume_state,
                compute_lm_head=False,
            )
    finally:
        if saved is not None:
            setattr(encoder, "aop_prune_config", saved)
    reps_last_t, _ = _dump_pool(model, out_last, am, device)
    return reps_last_t


def _dump_local_rows(infos, cand_id2row, batch_size):
    """Map each query's ``cand_names`` to candidate-matrix row indices, dropping unknown names."""
    rows_list = []
    for b in range(batch_size):
        rows = (cand_id2row.get(str(cid), -1) for cid in infos[b]["cand_names"])
        rows_list.append([r for r in rows if r >= 0])
    return rows_list


def _dump_exit_labels(reps_mid_t, reps_last_t, cand_mid_t, cand_last_t, rows_list):
    """Per-query exit label: 1 when mid-layer and last-layer top-1 candidates agree.

    ``rows_list`` is None for global ranking (all candidates). Otherwise it
    holds one list of candidate row indices per query; queries without rows get 0.
    """
    if rows_list is None:
        sim_mid = (reps_mid_t @ cand_mid_t.T).detach().float().cpu().numpy()
        sim_last = (reps_last_t @ cand_last_t.T).detach().float().cpu().numpy()
        return compute_label_top1_equal_global(sim_mid, sim_last)
    empty = np.zeros(0, dtype=np.float32)
    sv_mid_list, sv_last_list = [], []
    for b, rows in enumerate(rows_list):
        if not rows:
            sv_mid_list.append(empty)
            sv_last_list.append(empty)
            continue
        sv_mid_list.append((reps_mid_t[b:b + 1] @ cand_mid_t[rows].T)[0].detach().float().cpu().numpy())
        sv_last_list.append((reps_last_t[b:b + 1] @ cand_last_t[rows].T)[0].detach().float().cpu().numpy())
    return compute_label_top1_equal_local(sv_mid_list, sv_last_list)


def _dump_write_scaler(path, sum_feat, sum2_feat, n):
    """Write per-feature mean/std (std floored at 1e-6) computed from running sums to ``path``."""
    mean = sum_feat / n
    var = sum2_feat / n - mean ** 2
    std = [float(max(1e-6, math.sqrt(max(0.0, v)))) for v in var.tolist()]
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"mean": mean.tolist(), "std": std, "in_dim": len(std), "n": n}, f, indent=2, ensure_ascii=False)


def main():
    """Dump Phase-A early-exit features and labels for every dataset in the YAML config.

    For each dataset:
      * candidates are encoded at ``EE_LAYER`` and at the last layer;
      * queries are encoded up to ``EE_LAYER``, then resumed to the last layer;
      * per query, a feature vector and ``y_exit`` (1 when mid- and last-layer
        top-1 candidates agree) are written to ``<dataset>_phaseA_features.jsonl``;
      * per-feature mean/std are written to ``<dataset>_phaseA_scaler.json``.

    Environment:
        EE_LAYER (falls back to AOP_LAYER, then 12), EE_FEAT_TOPK (default 200),
        DUMP_EXIT_NO_AOP (default on: AOP pruning disabled for the query passes).
    """
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    rank = _dump_rank()  # global rank (dist.get_rank()), 0 when single-process
    is_main = rank == 0

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    os.makedirs(data_args.encode_output_path, exist_ok=True)

    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, "model_backbone", model_backbone)
        setattr(training_args, "model_backbone", model_backbone)

    # Rank 0 loads first; the other ranks follow after the barrier with a staggered delay
    # (presumably so shared downloads/cache writes are not raced — confirm).
    if is_main:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    if dist.is_initialized():
        dist.barrier()
    if not is_main:
        processor = load_processor(model_args, data_args)
        time.sleep(random.randint(2 * rank, 3 * rank))
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)

    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)

    # NOTE(review): AOP_* env settings (see get_env_aop_config) are not applied here;
    # the AOP config used is whatever model.encoder.aop_prune_config already holds.
    ee_layer = int(os.environ.get("EE_LAYER", os.environ.get("AOP_LAYER", "12")))
    feat_topk = int(os.environ.get("EE_FEAT_TOPK", "200"))
    force_no_aop = os.environ.get("DUMP_EXIT_NO_AOP", "1").strip().lower() in {"1", "true", "yes", "on"}
    print_master(f"[DUMP] EE_LAYER={ee_layer}, FEAT_TOPK={feat_topk}, NO_AOP={force_no_aop}")

    with open(data_args.dataset_config, "r", encoding="utf-8") as yf:
        dataset_configs = yaml.safe_load(yf)

    for dataset_name, task_cfg in dataset_configs.items():
        if dist.is_initialized():
            dist.barrier()
        print_master(f"\n[DUMP] {dataset_name}")

        # Resolve relative data paths against the shared base directory.
        if data_args.data_basedir:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if task_cfg.get(key):
                    task_cfg[key] = os.path.join(data_args.data_basedir, task_cfg[key])

        full_qry, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_cfg)
        full_cand = generate_cand_dataset(full_qry, corpus)

        qry_loader = DataLoader(
            full_qry,
            batch_size=training_args.per_device_eval_batch_size,
            collate_fn=MultimodalEvalDataCollator(processor, model_args, data_args, "qry"),
            num_workers=training_args.dataloader_num_workers,
        )
        cand_loader = DataLoader(
            full_cand,
            batch_size=training_args.per_device_eval_batch_size,
            collate_fn=MultimodalEvalDataCollator(processor, model_args, data_args, "cand"),
            num_workers=training_args.dataloader_num_workers,
        )

        cand_mid_np, cand_last_np, cand_ids = encode_candidates_both_layers(model, cand_loader, training_args, mid_layer=ee_layer)
        cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
        device = training_args.device
        cand_mid_t = torch.from_numpy(cand_mid_np).to(device=device, dtype=torch.bfloat16)
        cand_last_t = torch.from_numpy(cand_last_np).to(device=device, dtype=torch.bfloat16)

        feat_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_phaseA_features.jsonl")
        scaler_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_phaseA_scaler.json")
        rank_global = task_cfg.get("eval_type", "global") == "global"
        enc_cfg = model.encoder.config

        # Running sums for the scaler, kept in float64: E[x^2] - E[x]^2 cancels badly in float32.
        sum_feat = None
        sum2_feat = None
        n_feat = 0
        qid_global = 0

        # Every rank runs the same (unsharded) loop. Only rank 0 writes the real file;
        # several ranks opening one path with "w" at once would clobber each other.
        with open(feat_path if is_main else os.devnull, "w", encoding="utf-8") as fout:
            for inputs, infos in tqdm(qry_loader, desc="[DUMP] Qrys", disable=False):
                inputs = batch_to_device(inputs, device)
                input_ids = inputs["input_ids"]
                B = input_ids.size(0)

                out_mid, reps_mid_t, am_mid = _dump_forward_mid(model, inputs, device, ee_layer, force_no_aop)

                if rank_global:
                    rows_list = None
                    feats_t = build_phaseA_features_global(reps_mid_t, cand_mid_t, am_mid, input_ids, enc_cfg, topk=feat_topk)
                else:
                    rows_list = _dump_local_rows(infos, cand_id2row, B)
                    feats_t = build_phaseA_features_local(reps_mid_t, cand_mid_t, am_mid, input_ids, enc_cfg, rows_list, topk=feat_topk)
                feats_np = feats_t.detach().float().cpu().numpy()

                reps_last_t = _dump_forward_last(model, out_mid, device, force_no_aop)
                y = _dump_exit_labels(reps_mid_t, reps_last_t, cand_mid_t, cand_last_t, rows_list)

                # Columns 5/6 are L_txt / L_vis in the Phase-A feature layout.
                L_txt = feats_np[:, 5]
                L_vis = feats_np[:, 6]
                types = np.where((L_vis > 0) & (L_txt == 0), "I", np.where((L_txt > 0) & (L_vis == 0), "T", "IT"))

                for b in range(B):
                    row = {
                        "dataset": dataset_name,
                        "qid": int(qid_global + b),
                        "type": str(types[b]),
                        "feats": feats_np[b].tolist(),
                        "y_exit": int(y[b]),
                    }
                    fout.write(json.dumps(row, ensure_ascii=False) + "\n")

                feats64 = feats_np.astype(np.float64)
                batch_sum = feats64.sum(axis=0)
                batch_sum2 = np.square(feats64).sum(axis=0)
                sum_feat = batch_sum if sum_feat is None else sum_feat + batch_sum
                sum2_feat = batch_sum2 if sum2_feat is None else sum2_feat + batch_sum2
                n_feat += feats_np.shape[0]
                qid_global += B

        if n_feat > 0 and is_main:
            _dump_write_scaler(scaler_path, sum_feat, sum2_feat, n_feat)
            print_master(f"[DUMP] {dataset_name}: features -> {feat_path}, scaler -> {scaler_path}, n={n_feat}")

        if dist.is_initialized():
            dist.barrier()
| |
|
| |
|
| | |
| | from datasets import concatenate_datasets |
| |
|
# Script entry point: run the dump only when this file is executed directly.
if __name__ == "__main__":
    main()