|
|
import os |
|
|
import re |
|
|
import ast |
|
|
import math |
|
|
import yaml |
|
|
import warnings |
|
|
from datetime import datetime |
|
|
from dataclasses import dataclass, field |
|
|
from collections import defaultdict |
|
|
from typing import Any, Callable, Optional, Union, Sized, Dict, Tuple, List, Literal, Type |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from torch import nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
import datasets |
|
|
|
|
|
from PIL import Image |
|
|
|
|
|
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config |
|
|
from trl.models import unwrap_model_for_generation |
|
|
|
|
|
from transformers import ( |
|
|
TrainingArguments, |
|
|
Trainer, |
|
|
GenerationConfig, |
|
|
) |
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.utils import ( |
|
|
is_safetensors_available, |
|
|
is_peft_available |
|
|
) |
|
|
|
|
|
if is_safetensors_available(): |
|
|
import safetensors.torch |
|
|
if is_peft_available():
    from peft import PeftConfig, get_peft_model, PeftModel
|
|
from accelerate.utils import is_peft_model, set_seed |
|
|
|
|
|
from qwen_vl_utils import process_vision_info |
|
|
|
|
|
from src.model.vlm_backbone.qwen2_5_vl_gp.process_gp import Qwen2_5_VL_GP_Processor |
|
|
|
|
|
from transformers.trainer import ( |
|
|
logger, |
|
|
TRAINING_ARGS_NAME, |
|
|
CONFIG_NAME, |
|
|
ADAPTER_WEIGHTS_NAME, |
|
|
ADAPTER_SAFE_WEIGHTS_NAME, |
|
|
WEIGHTS_NAME, |
|
|
WEIGHTS_INDEX_NAME, |
|
|
SAFE_WEIGHTS_NAME, |
|
|
SAFE_WEIGHTS_INDEX_NAME, |
|
|
FSDP_MODEL_NAME, |
|
|
) |
|
|
|
|
|
from src.model.vlm_backbone.qwen2_5_vl_gp.warppers import debug_calls |
|
|
from src.utils_gp import ( |
|
|
LLMClient, |
|
|
norm_bboxes, |
|
|
extract_one_bbox_from_str, |
|
|
cal_paired_ious, |
|
|
print_rank0 |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
QUERY_KEY = "query" |
|
|
IMG_PATH_KEY = "img_path" |
|
|
ANSWER_KEY = "answer" |
|
|
NORMED_BBOXES_KEY = "normed_bboxes" |
|
|
SCORE_FUNCS_KEY = "score_funcs" |
|
|
|
|
|
REMAIN_KEYS = [ |
|
|
QUERY_KEY, |
|
|
IMG_PATH_KEY, |
|
|
NORMED_BBOXES_KEY, |
|
|
ANSWER_KEY, |
|
|
SCORE_FUNCS_KEY, |
|
|
] |
|
|
|
|
|
MAPPER_REGISTRY: Dict[str, Callable] = {} |
|
|
FILTER_REGISTRY: Dict[str, Callable] = {} |
|
|
|
|
|
def register_mappers(): |
|
|
def wrapper(func): |
|
|
name = func.__name__.replace("_dataset_mapper", "") |
|
|
MAPPER_REGISTRY[name] = func |
|
|
return func |
|
|
return wrapper |
|
|
|
|
|
def register_filters(): |
|
|
def wrapper(func): |
|
|
name = func.__name__.replace("_dataset_filter", "") |
|
|
FILTER_REGISTRY[name] = func |
|
|
return func |
|
|
return wrapper |
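
# Registry usage sketch: the decorators above key each function by its name
# with the "_dataset_mapper" / "_dataset_filter" suffix stripped, which is the
# short name referenced in the YAML config. For example:
#   @register_mappers()
#   def cot_train_dataset_mapper(one_data, **kwargs): ...
#   # -> MAPPER_REGISTRY["cot_train"] is cot_train_dataset_mapper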
|
|
|
|
|
|
|
|
@register_mappers() |
|
|
def cot_train_dataset_mapper(one_data, **kwargs): |
|
|
query = one_data['question'] |
|
|
if 'prompt' in kwargs: |
|
|
query = kwargs['prompt'].format(query) |
|
|
answer = one_data['answer'] |
|
|
image = one_data['image'] |
|
|
dataset = one_data['dataset'] |
|
|
img_path = os.path.join(kwargs['img_dir'], "cot", dataset, image) |
|
|
bboxes = one_data['bboxs'] |
|
|
return { |
|
|
QUERY_KEY: query, |
|
|
ANSWER_KEY: answer, |
|
|
IMG_PATH_KEY: img_path, |
|
|
NORMED_BBOXES_KEY: bboxes, |
|
|
SCORE_FUNCS_KEY: kwargs['score_funcs'] |
|
|
} |
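
# Record flow sketch for cot_train (values invented for illustration):
#   in:  {"question": "Where is the cat?", "answer": "...", "image": "0001.jpg",
#         "dataset": "vqa_cot", "bboxs": [[10, 20, 110, 220]]}
#   out: {"query": ..., "answer": ..., "img_path": "<img_dir>/cot/vqa_cot/0001.jpg",
#         "normed_bboxes": [[10, 20, 110, 220]], "score_funcs": [...]}
# Note that the boxes are still in pixel space at this point; normalization
# happens in norm_bboxes_dataset_mapper when it is listed under
# 'additional_mappers'.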
|
|
|
|
|
|
|
|
@register_mappers() |
|
|
def cot_train_fullmask_dataset_mapper(one_data, **kwargs): |
|
|
query = one_data['question'] |
|
|
if 'prompt' in kwargs: |
|
|
query = kwargs['prompt'].format(query) |
|
|
answer = one_data['answer'] |
|
|
image = one_data['image'] |
|
|
dataset = one_data['dataset'] |
|
|
img_path = os.path.join(kwargs['img_dir'], "cot", dataset, image) |
|
|
normed_bboxes = [[0.0, 0.0, 1.0, 1.0]] |
|
|
return { |
|
|
QUERY_KEY: query, |
|
|
ANSWER_KEY: answer, |
|
|
IMG_PATH_KEY: img_path, |
|
|
NORMED_BBOXES_KEY: normed_bboxes, |
|
|
SCORE_FUNCS_KEY: kwargs['score_funcs'] |
|
|
} |
|
|
|
|
|
|
|
|
@register_mappers() |
|
|
def norm_bboxes_dataset_mapper(one_data, **kwargs): |
|
|
bboxes = one_data.pop(NORMED_BBOXES_KEY) |
|
|
if 'width' in one_data: |
|
|
width = one_data['width'] |
|
|
height = one_data['height'] |
|
|
else: |
|
|
img_path = one_data[IMG_PATH_KEY] |
|
|
img_pil = Image.open(img_path) |
|
|
width, height = img_pil.size |
|
|
img_pil.close() |
|
|
normed_bboxes = norm_bboxes(bboxes, height, width, bbox_type=kwargs['bbox_type']) |
|
|
one_data[NORMED_BBOXES_KEY] = normed_bboxes |
|
|
return one_data |
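
# Numeric sketch, assuming norm_bboxes divides xyxy pixel coordinates by the
# image size (the exact convention is defined in src.utils_gp.norm_bboxes):
#   bboxes=[[50, 100, 150, 200]], width=400, height=400, bbox_type="xyxy"
#   -> [[0.125, 0.25, 0.375, 0.5]]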
|
|
|
|
|
|
|
|
@register_filters() |
|
|
def image_exist_dataset_filter(one_data, **kwargs): |
|
|
img_path = one_data[IMG_PATH_KEY] |
|
|
try: |
|
|
img = Image.open(img_path) |
|
|
img.close() |
|
|
return True |
|
|
except (FileNotFoundError, OSError) as e: |
|
|
print_rank0(f"Image not found or invalid: {img_path}. Error: {e}") |
|
|
return False |
|
|
except Exception as e: |
|
|
print_rank0(f"Unexpected error while checking image: {img_path}. Error: {e}") |
|
|
return False |
|
|
|
|
|
@register_filters() |
|
|
def inputs_seq_length_dataset_filter(one_data, **kwargs): |
|
|
processor = kwargs['processor'] |
|
|
max_input_seq_length = kwargs.get('max_input_seq_length', None) |
|
|
max_input_remain_seq_length = kwargs.get('max_input_remain_seq_length', None) |
|
|
if max_input_seq_length is None and max_input_remain_seq_length is None: |
|
|
return True |
|
|
img_path = one_data[IMG_PATH_KEY] |
|
|
query = one_data[QUERY_KEY] |
|
|
normed_bboxes = [one_data[NORMED_BBOXES_KEY]] if max_input_remain_seq_length is not None else None |
|
|
messages = [[{"role": "user", "content": [{"type": "image", "image": img_path}, {"type": "text", "text": query}]}]] |
|
|
text = processor.apply_chat_template( |
|
|
messages, tokenize=False, add_generation_prompt=True |
|
|
) |
|
|
image_inputs, video_inputs = process_vision_info(messages) |
|
|
inputs = processor( |
|
|
text=text, |
|
|
images=image_inputs, |
|
|
videos=video_inputs, |
|
|
normed_bboxes=normed_bboxes, |
|
|
padding=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
seq_length = inputs.input_ids.shape[1] |
|
|
if max_input_seq_length is not None and seq_length > max_input_seq_length: |
|
|
return False |
|
|
|
|
|
if max_input_remain_seq_length is not None: |
|
|
ref_token_masks = inputs.ref_token_masks[0] |
|
|
reduced_num = ref_token_masks.numel() - ref_token_masks.sum().item() |
|
|
remain_seq_length = seq_length - reduced_num |
|
|
if remain_seq_length > max_input_remain_seq_length: |
|
|
return False |
|
|
return True |
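
# Arithmetic sketch for the "remain" budget (illustrative numbers): for a
# 1200-token input whose ref_token_masks keeps 150 of 1000 image tokens,
# reduced_num = 1000 - 150 = 850, so remain_seq_length = 1200 - 850 = 350,
# i.e. the sequence length left after the masked-out image tokens are dropped.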
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LOSS_REGISTRY: Dict[str, Type[nn.Module]] = {} |
|
|
|
|
|
def register_loss(loss_class): |
|
|
name = loss_class.__name__ |
|
|
if name in LOSS_REGISTRY: |
|
|
raise ValueError(f"Loss class '{name}' is already registered.") |
|
|
LOSS_REGISTRY[name] = loss_class |
|
|
return loss_class |
|
|
|
|
|
|
|
|
@register_loss |
|
|
class DiceLoss(nn.Module): |
|
|
def __init__(self, epsilon: float = 1e-6, **kwargs): |
|
|
super().__init__() |
|
|
self.epsilon = epsilon |
|
|
|
|
|
def forward(self, |
|
|
image_token_mask_logits: List[torch.Tensor], |
|
|
ref_token_masks: List[torch.Tensor] |
|
|
) -> torch.Tensor: |
|
|
if not isinstance(image_token_mask_logits, list) or not isinstance(ref_token_masks, list): |
|
|
raise TypeError("Inputs must be lists of tensors.") |
|
|
if len(image_token_mask_logits) != len(ref_token_masks): |
|
|
raise ValueError(f"Input lists must have the same length, but got " |
|
|
f"{len(image_token_mask_logits)} and {len(ref_token_masks)}") |
|
|
        if len(image_token_mask_logits) == 0:
            # Empty batch: nothing to score, return a zero loss.
            return torch.tensor(0.0)
|
|
|
|
|
batch_size = len(image_token_mask_logits) |
|
|
total_dice_loss = 0.0 |
|
|
|
|
|
for i in range(batch_size): |
|
|
pred_mask_1d = image_token_mask_logits[i].flatten().sigmoid() |
|
|
gt_mask_1d = ref_token_masks[i].flatten().to(pred_mask_1d.device, dtype=torch.float) |
|
|
intersection = (pred_mask_1d * gt_mask_1d).sum() |
|
|
pred_sum = pred_mask_1d.sum() |
|
|
gt_sum = gt_mask_1d.sum() |
|
|
dice_coefficient = (2.0 * intersection + self.epsilon) / (pred_sum + gt_sum + self.epsilon) |
|
|
total_dice_loss += (1.0 - dice_coefficient) |
|
|
|
|
|
return total_dice_loss / batch_size |
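
# Worked example for one sample (epsilon omitted for readability): with
# sigmoid(logits) = [0.5, 0.5, 0.5, 0.5] and gt = [1, 1, 0, 0]:
#   intersection = 1.0, pred_sum = 2.0, gt_sum = 2.0
#   dice = 2 * 1.0 / (2.0 + 2.0) = 0.5  ->  per-sample loss = 0.5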
|
|
|
|
|
|
|
|
@register_loss |
|
|
class BCELoss(nn.Module): |
|
|
    def __init__(self, **kwargs):
        super().__init__()
|
|
|
|
|
def forward(self, |
|
|
image_token_mask_logits: List[torch.Tensor], |
|
|
ref_token_masks: List[torch.Tensor] |
|
|
) -> torch.Tensor: |
|
|
|
|
|
batch_size = len(image_token_mask_logits) |
|
|
total_bce_loss = 0.0 |
|
|
for i in range(batch_size): |
|
|
pred_mask_1d = image_token_mask_logits[i].flatten() |
|
|
gt_mask_1d = ref_token_masks[i].flatten().to(pred_mask_1d.device) |
|
|
bce_loss = F.binary_cross_entropy_with_logits( |
|
|
pred_mask_1d.float(), |
|
|
gt_mask_1d.float(), |
|
|
) |
|
|
total_bce_loss += bce_loss |
|
|
return total_bce_loss / batch_size |
|
|
|
|
|
|
|
|
@register_loss |
|
|
class MaskLoss(nn.Module): |
|
|
def __init__(self, |
|
|
dice_weight: float = 0.5, |
|
|
bce_weight: float = 0.5, |
|
|
epsilon: float = 1e-6, |
|
|
**kwargs): |
|
|
super().__init__() |
|
|
self.dice_loss = DiceLoss(epsilon=epsilon) |
|
|
self.bce_loss = BCELoss() |
|
|
self.dice_weight = dice_weight |
|
|
self.bce_weight = bce_weight |
|
|
|
|
|
def forward(self, image_token_mask_logits: List[torch.Tensor], |
|
|
ref_token_masks: List[torch.Tensor] |
|
|
) -> torch.Tensor: |
|
|
dice_loss = self.dice_loss(image_token_mask_logits, ref_token_masks) |
|
|
bce_loss = self.bce_loss(image_token_mask_logits, ref_token_masks) |
|
|
return self.dice_weight * dice_loss + self.bce_weight * bce_loss |
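
# Minimal usage sketch (dummy tensors; the class is looked up exactly as it is
# registered above):
#   criterion = LOSS_REGISTRY["MaskLoss"](dice_weight=0.5, bce_weight=0.5)
#   logits = [torch.randn(24, 24)]           # per-sample image-token logits
#   masks = [torch.randint(0, 2, (24, 24))]  # per-sample binary reference masks
#   loss = criterion(logits, masks)          # 0.5 * dice + 0.5 * bce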
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SCORE_REGISTRY: Dict[str, Callable] = {} |
|
|
|
|
|
def register_score(): |
|
|
def wrapper(func): |
|
|
name = func.__name__.replace("_score", "") |
|
|
SCORE_REGISTRY[name] = func |
|
|
return func |
|
|
return wrapper |
|
|
|
|
|
@register_score() |
|
|
def llm_score(query, completion, answer, args): |
|
|
""" |
|
|
    The YAML config may list 'score_funcs: [llm]'. This project does not use
    these scores, so returning zeros as a placeholder is sufficient.
|
|
""" |
|
|
|
|
|
if isinstance(query, list): |
|
|
return [0.0] * len(query) |
|
|
return [0.0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_rel_path(rel_path: str, base_dir: str) -> str: |
|
|
""" |
|
|
Resolve a relative path against base_dir; if not found, try parent dirs up to 4 levels. |
|
|
""" |
|
|
if os.path.isabs(rel_path): |
|
|
return rel_path |
|
|
candidates = [os.path.join(base_dir, rel_path)] |
|
|
parent = base_dir |
|
|
for _ in range(4): |
|
|
parent = os.path.dirname(parent) |
|
|
if not parent or parent in ("/", ""): |
|
|
break |
|
|
candidates.append(os.path.join(parent, rel_path)) |
|
|
for cand in candidates: |
|
|
if os.path.exists(cand): |
|
|
return cand |
|
|
return candidates[0] |
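
# Resolution order sketch: with base_dir="/repo/configs/exp" and
# rel_path="data/train.json", the candidates are tried in order:
#   /repo/configs/exp/data/train.json
#   /repo/configs/data/train.json
#   /repo/data/train.json
# The first existing candidate wins; if none exists, the first one is returned
# so the caller's error message still shows a sensible path.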
|
|
|
|
|
|
|
|
class GPDataset(torch.utils.data.Dataset): |
|
|
""" |
|
|
A PyTorch Dataset that loads and combines multiple datasets |
|
|
based on a YAML configuration file. It handles sampling |
|
|
and applies specified mapping functions. |
|
|
""" |
|
|
@classmethod |
|
|
def _load_config(cls, config_path: str) -> Dict[str, Any]: |
|
|
print_rank0(f"Loading configuration from: {config_path}") |
|
|
try: |
|
|
with open(config_path, 'r', encoding='utf-8') as f: |
|
|
conf = yaml.safe_load(f) |
|
|
if conf is None: |
|
|
raise ValueError("YAML config is empty.") |
|
|
|
|
|
base_dir = os.path.dirname(config_path) |
|
|
|
|
|
if 'datasets' not in conf: |
|
|
if 'train_dataset' in conf: |
|
|
ds_yaml = _resolve_rel_path(conf['train_dataset'], base_dir) |
|
|
print_rank0(f"Loading dataset config from: {ds_yaml}") |
|
|
with open(ds_yaml, 'r', encoding='utf-8') as f: |
|
|
conf2 = yaml.safe_load(f) |
|
|
if conf2 is None or 'datasets' not in conf2: |
|
|
raise ValueError(f"'{ds_yaml}' missing 'datasets' key.") |
|
|
conf = conf2 |
|
|
base_dir = os.path.dirname(ds_yaml) |
|
|
else: |
|
|
raise ValueError("YAML config is missing both 'datasets' and 'train_dataset' keys.") |
|
|
|
|
|
conf['__root_dir__'] = base_dir |
|
|
print_rank0("Configuration loaded successfully.") |
|
|
return conf |
|
|
except Exception as e: |
|
|
print_rank0(f"Failed to load config: {e}") |
|
|
raise |
|
|
|
|
|
@classmethod |
|
|
def _apply_sampling(cls, dataset: datasets.Dataset, strategy: Optional[str], seed: Optional[int] = None) -> datasets.Dataset: |
|
|
"""Applies sampling strategy to a dataset.""" |
|
|
if not strategy: |
|
|
print_rank0("No sampling strategy specified, using full dataset.") |
|
|
return dataset |
|
|
|
|
|
try: |
|
|
parts = strategy.split(':') |
|
|
if len(parts) != 2: |
|
|
raise ValueError(f"Invalid sampling strategy format: '{strategy}'. Expected 'type:value'.") |
|
|
strat_type, strat_value = parts[0].lower(), parts[1] |
|
|
num_samples = int(strat_value) |
|
|
total_size = len(dataset) |
|
|
if num_samples <= 0: |
|
|
raise ValueError(f"Sampling value must be positive, got: {num_samples} [{strategy}]") |
|
|
num_samples = min(num_samples, total_size) |
|
|
|
|
|
print_rank0(f"Applying sampling: {strategy} ({num_samples} samples) to dataset of size {total_size}") |
|
|
|
|
|
if strat_type == "first": |
|
|
return dataset.select(range(num_samples)) |
|
|
elif strat_type == "end": |
|
|
start_index = max(0, total_size - num_samples) |
|
|
return dataset.select(range(start_index, total_size)) |
|
|
elif strat_type == "random": |
|
|
shuffled_dataset = dataset.shuffle(seed=seed) |
|
|
return shuffled_dataset.select(range(num_samples)) |
|
|
else: |
|
|
print_rank0(f"Warning: Unknown sampling strategy type: '{strat_type}'. Using full dataset.") |
|
|
return dataset |
|
|
except ValueError as e: |
|
|
print_rank0(f"Error parsing sampling strategy '{strategy}': {e}. Using full dataset.") |
|
|
return dataset |
|
|
except Exception as e: |
|
|
print_rank0(f"An unexpected error occurred during sampling: {e}. Using full dataset.") |
|
|
return dataset |
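
    # Strategy string examples (parsed above as "type:value"):
    #   "first:1000"  -> the first 1000 rows
    #   "end:500"     -> the last 500 rows
    #   "random:2000" -> 2000 rows after a seeded shuffle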
|
|
|
|
|
@classmethod |
|
|
def _all_processed_datasets(cls, config, processor, args): |
|
|
root_dir = config.get('__root_dir__', os.getcwd()) |
|
|
all_processed_datasets: Dict[str, datasets.Dataset] = {} |
|
|
for i, dataset_config in enumerate(config['datasets']): |
|
|
print_rank0(f"\nProcessing dataset entry {i+1}/{len(config['datasets'])}...") |
|
|
json_path = dataset_config.get('json_path') |
|
|
if not json_path: |
|
|
print_rank0(f"Warning: Skipping dataset entry {i+1} due to missing 'json_path'.") |
|
|
continue |
|
|
json_path = _resolve_rel_path(json_path, root_dir) |
|
|
|
|
|
            base_name = os.path.splitext(os.path.basename(json_path))[0]
|
|
dataset_name = dataset_config.get('dataset_name', base_name) |
|
|
|
|
|
sampling_strategy = dataset_config.get('sampling_strategy', None) |
|
|
            sampling_seed = dataset_config.get('sampling_seed', getattr(args, 'sampling_seed', 42))
|
|
|
|
|
mapper_name = dataset_config.get('mapper') |
|
|
bbox_type = dataset_config.get('bbox_type') |
|
|
|
|
|
|
|
|
if 'img_dir' in dataset_config: |
|
|
img_dir = _resolve_rel_path(dataset_config['img_dir'], root_dir) |
|
|
else: |
|
|
img_dir = getattr(args, 'img_dir', None) |
|
|
if img_dir is not None: |
|
|
img_dir = _resolve_rel_path(img_dir, root_dir) |
|
|
|
|
|
additional_mappers = dataset_config.get('additional_mappers', []) |
|
|
score_funcs = dataset_config.get('score_funcs', []) |
|
|
prompt = dataset_config.get('prompt', None) |
|
|
|
|
|
            max_input_seq_length = dataset_config.get('max_input_seq_length', getattr(args, 'max_input_seq_length', None))
            max_input_remain_seq_length = dataset_config.get('max_input_remain_seq_length', getattr(args, 'max_input_remain_seq_length', None))
|
|
|
|
|
|
|
|
if score_funcs: |
|
|
filtered = [] |
|
|
for sf in score_funcs: |
|
|
if sf in SCORE_REGISTRY: |
|
|
filtered.append(sf) |
|
|
else: |
|
|
print_rank0(f"Warning: Score function '{sf}' not registered. Will ignore.") |
|
|
score_funcs = filtered |
|
|
|
|
|
try: |
|
|
print_rank0(f"Loading raw data from: {json_path}") |
|
|
raw_dataset = datasets.load_dataset('json', data_files=json_path, split='train') |
|
|
print_rank0(f"Loaded {len(raw_dataset)} examples raw.") |
|
|
|
|
|
sampled_dataset = cls._apply_sampling(raw_dataset, sampling_strategy, sampling_seed) |
|
|
if len(sampled_dataset) == 0: |
|
|
print_rank0("Dataset is empty after sampling, skipping.") |
|
|
continue |
|
|
print_rank0(f"Dataset size after sampling: {len(sampled_dataset)}") |
|
|
|
|
|
                if mapper_name not in MAPPER_REGISTRY:
                    raise ValueError(f"Mapper '{mapper_name}' is not registered in MAPPER_REGISTRY.")
                mapper_func = MAPPER_REGISTRY[mapper_name]
|
|
print_rank0(f"Applying mapper: '{mapper_name}'") |
|
|
mapper_kwargs = { |
|
|
'img_dir': img_dir, |
|
|
'score_funcs': score_funcs, |
|
|
} |
|
|
if prompt is not None: |
|
|
mapper_kwargs['prompt'] = prompt |
|
|
print_rank0(f"Mapper arguments: {mapper_kwargs}") |
|
|
|
|
|
processed_dataset = sampled_dataset.map( |
|
|
mapper_func, |
|
|
num_proc=8, |
|
|
fn_kwargs=mapper_kwargs, |
|
|
) |
|
|
|
|
|
processed_dataset = processed_dataset.remove_columns( |
|
|
[col for col in processed_dataset.column_names if col not in REMAIN_KEYS] |
|
|
) |
|
|
|
|
|
print_rank0("Applying dataset filter: 'image_exist_dataset_filter'") |
|
|
processed_dataset = processed_dataset.filter( |
|
|
image_exist_dataset_filter, |
|
|
num_proc=8, |
|
|
fn_kwargs={} |
|
|
) |
|
|
print_rank0(f"Processed dataset size after image_exist_dataset_filter: {len(processed_dataset)}") |
|
|
|
|
|
if max_input_seq_length is not None or max_input_remain_seq_length is not None: |
|
|
processed_dataset = processed_dataset.filter( |
|
|
inputs_seq_length_dataset_filter, |
|
|
num_proc=8, |
|
|
fn_kwargs={ |
|
|
'processor': processor, |
|
|
'max_input_seq_length': max_input_seq_length, |
|
|
'max_input_remain_seq_length': max_input_remain_seq_length, |
|
|
} |
|
|
) |
|
|
print_rank0(f"Processed dataset size after inputs_seq_length_dataset_filter: {len(processed_dataset)}") |
|
|
|
|
|
for additional_mapper in additional_mappers: |
|
|
mapper_func = MAPPER_REGISTRY[additional_mapper] |
|
|
print_rank0(f"Applying additional mapper: '{additional_mapper}'") |
|
|
processed_dataset = processed_dataset.map( |
|
|
mapper_func, |
|
|
num_proc=8, |
|
|
fn_kwargs={ |
|
|
'bbox_type': bbox_type, |
|
|
} |
|
|
) |
|
|
print_rank0(f"Processed dataset size: {len(processed_dataset)}") |
|
|
if len(processed_dataset) == 0: |
|
|
print_rank0(f"Warning: Processed dataset {dataset_name} is empty after mapping. Skipping.") |
|
|
continue |
|
|
|
|
|
if dataset_name in all_processed_datasets: |
|
|
                    unique_name = f"{dataset_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                    print_rank0(f"Warning: Dataset name '{dataset_name}' already exists. Renaming to '{unique_name}'")
                    all_processed_datasets[unique_name] = processed_dataset
|
|
else: |
|
|
all_processed_datasets[dataset_name] = processed_dataset |
|
|
|
|
|
except FileNotFoundError: |
|
|
print_rank0(f"Error: Data file not found for dataset entry {i+1}: {json_path}. Skipping.") |
|
|
except Exception as e: |
|
|
print_rank0(f"Error processing dataset entry {i+1} ({json_path}): {e}. Skipping.") |
|
|
|
|
|
return all_processed_datasets |
|
|
|
|
|
|
|
|
def __init__(self, config_path: str, processor: Qwen2_5_VL_GP_Processor, script_args: Optional[Any] = None): |
|
|
""" |
|
|
Initializes the GPDataset. |
|
|
|
|
|
Args: |
|
|
config_path (str): Path to the YAML configuration file. |
|
|
processor (Qwen2_5_VL_GP_Processor): Processor for handling text and vision data. |
|
|
script_args (Any, optional): Additional arguments passed from the script |
|
|
(e.g., training args, could contain seed). Defaults to None. |
|
|
""" |
|
|
super().__init__() |
|
|
self.args = script_args |
|
|
self.config = self._load_config(config_path) |
|
|
self.processor = processor |
|
|
all_processed_datasets = self._all_processed_datasets(self.config, self.processor, self.args) |
|
|
if all_processed_datasets: |
|
|
print_rank0(f"\nConcatenating {len(all_processed_datasets)} processed dataset(s)...") |
|
|
self.final_dataset = datasets.concatenate_datasets(list(all_processed_datasets.values())) |
|
|
if len(self.final_dataset) == 0: |
|
|
raise ValueError("Final dataset is empty after concatenation.") |
|
|
print_rank0(f"Final combined dataset size: {len(self.final_dataset)}") |
|
|
print_rank0(f"Final dataset features: {self.final_dataset.features}") |
|
|
else: |
|
|
raise ValueError("No datasets were successfully processed. Please check your configuration.") |
|
|
|
|
|
|
|
def __len__(self) -> int: |
|
|
        return len(self.final_dataset) if self.final_dataset is not None else 0
|
|
|
|
|
def __getitem__(self, index: int) -> Dict[str, Any]: |
|
|
if self.final_dataset is None: |
|
|
raise IndexError("Dataset is not initialized or is empty.") |
|
|
if not 0 <= index < len(self.final_dataset): |
|
|
raise IndexError(f"Index {index} out of bounds for dataset of size {len(self.final_dataset)}") |
|
|
return self.final_dataset[index] |
|
|
|
|
|
|
|
|
@classmethod |
|
|
def get_processed_dataset_dict(cls, config_path: str, processor: Qwen2_5_VL_GP_Processor, script_args: Optional[Any] = None) -> Dict[str, datasets.Dataset]: |
|
|
config = cls._load_config(config_path) |
|
|
all_processed_datasets = cls._all_processed_datasets(config, processor, script_args) |
|
|
return all_processed_datasets |
|
|
|
|
|
|
|
|
|
|
|
class GPCollator: |
|
|
def __init__(self, processor, is_sft): |
|
|
self.processor = processor |
|
|
self.is_sft = is_sft |
|
|
self.im_start_id = self.processor.tokenizer.encode("<|im_start|>")[0] |
|
|
|
|
|
    def _prepare_labels_from_input_ids(self, input_ids):
        """Mask every token up to the start of the last turn so that only the
        assistant response is supervised."""
        B, L = input_ids.shape
        labels = input_ids.clone()
        # Find the last <|im_start|> in each row by flipping the match mask and
        # taking the first hit.
        mask = input_ids == self.im_start_id
        flipped_mask = mask.flip(dims=(1,))
        first_idx_in_flipped = torch.argmax(flipped_mask.int(), dim=1)
        last_pos = (L - 1) - first_idx_in_flipped
        # Skip the <|im_start|> token itself plus the role token and the
        # newline that follow it in the chat template.
        mask_until_idx = last_pos + 3
        mask_until_idx = torch.clamp(mask_until_idx, max=L)
        arange_l = torch.arange(L, device=input_ids.device).expand(B, -1)
        modification_mask = arange_l < mask_until_idx.unsqueeze(1)
        labels[modification_mask] = -100  # -100 is ignored by the CE loss
        return labels
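
    # Token-level sketch (assuming the Qwen chat template, where every turn
    # begins with "<|im_start|>{role}\n"): in
    #   ... <|im_start|> assistant \n {response tokens} <|im_end|>
    # last_pos is the index of the final <|im_start|>, and last_pos + 3 skips
    # it plus the role token and newline, so only the response tokens keep
    # real labels.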
|
|
|
|
|
def __call__(self, features): |
|
|
messages = [] |
|
|
normed_bboxes = [] |
|
|
answers = [] |
|
|
        queries = []
|
|
score_funcs = [] |
|
|
for feature in features: |
|
|
query = feature[QUERY_KEY] |
|
|
answer = feature[ANSWER_KEY] |
|
|
img_path = feature[IMG_PATH_KEY] |
|
|
            user_turn = {"role": "user", "content": [{"type": "image", "image": img_path}, {"type": "text", "text": query}]}
            if self.is_sft:
                messages.append([user_turn, {"role": "assistant", "content": [{"type": "text", "text": answer}]}])
            else:
                messages.append([user_turn])
|
|
normed_bboxes.append(feature[NORMED_BBOXES_KEY]) |
|
|
            queries.append(query)
|
|
answers.append(answer) |
|
|
score_funcs.append(feature[SCORE_FUNCS_KEY]) |
|
|
|
|
|
text = self.processor.apply_chat_template( |
|
|
messages, tokenize=False, add_generation_prompt=(not self.is_sft) |
|
|
) |
|
|
image_inputs, video_inputs = process_vision_info(messages) |
|
|
inputs = self.processor( |
|
|
text=text, |
|
|
normed_bboxes=normed_bboxes, |
|
|
images=image_inputs, |
|
|
videos=video_inputs, |
|
|
padding=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
|
|
|
if self.is_sft: |
|
|
labels = self._prepare_labels_from_input_ids(inputs.input_ids) |
|
|
inputs["labels"] = labels |
|
|
|
|
|
        inputs[QUERY_KEY] = queries
|
|
inputs[ANSWER_KEY] = answers |
|
|
inputs[SCORE_FUNCS_KEY] = score_funcs |
|
|
return inputs |
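
# End-to-end wiring sketch (the loader settings are placeholders, not
# prescribed by this module):
#   processor = Qwen2_5_VL_GP_Processor.from_pretrained(model_name_or_path)
#   train_dataset = GPDataset("configs/data.yaml", processor, script_args)
#   collator = GPCollator(processor, is_sft=True)
#   loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, collate_fn=collator)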