import ast
import re
from abc import ABC
from typing import List, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageOps, ImageDraw, ImageFont
from torchvision import transforms
from transformers import TextStreamer
from transformers.tokenization_utils import PreTrainedTokenizer as T
|
|
|
|
def load_image(image_path):
    """Open an image file and apply EXIF orientation correction.

    Returns the corrected PIL image, or None when loading fails
    (the error is printed, not raised).
    """
    try:
        return ImageOps.exif_transpose(Image.open(image_path))
    except Exception as e:
        print(f"error: {e}")
        return None
|
|
|
|
def re_match(text):
    """Extract all <|ref|>...<|/ref|><|det|>...<|/det|> spans from *text*.

    Returns a 3-tuple:
      - all regex match tuples (full_span, ref_content, det_content),
      - full spans whose ref content is the literal 'image',
      - all remaining full spans.
    """
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, text, re.DOTALL)

    image_spans = []
    other_spans = []
    for match_tuple in matches:
        full_span = match_tuple[0]
        target = image_spans if '<|ref|>image<|/ref|>' in full_span else other_spans
        target.append(full_span)
    return matches, image_spans, other_spans
|
|
|
|
def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Parse one re_match tuple into a (label, coordinate-list) pair.

    Args:
        ref_text: tuple (full_span, label, coords_str) as produced by re_match.
        image_width: unused; kept for interface compatibility.
        image_height: unused; kept for interface compatibility.

    Returns:
        (label_type, cor_list) on success, or None if parsing fails
        (the error is printed, not raised).
    """
    try:
        label_type = ref_text[1]
        # Use ast.literal_eval instead of eval: the coordinate string comes
        # from model output and must never be executed as arbitrary code.
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None

    return (label_type, cor_list)
|
|
|
|
def draw_bounding_boxes(image, refs, ouput_path):
    """Draw labelled bounding boxes from *refs* onto a copy of *image*.

    Args:
        image: source PIL image (left unmodified; boxes go on a copy).
        refs: iterable of re_match tuples (full_span, label, coords_str).
        ouput_path: directory whose `images/` subfolder receives crops of
            regions labelled 'image'. (Name kept as-is for compatibility.)

    Returns:
        A copy of *image* with outlines, translucent fills and text labels.
        Coordinates in refs are assumed to be on a 0-999 grid and are scaled
        to pixel space.
    """
    image_width, image_height = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Separate RGBA overlay so translucent fills can be alpha-composited on
    # top of the outlines at the end.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()

    # Running index used to name saved crops of 'image' regions.
    img_idx = 0

    for i, ref in enumerate(refs):
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if result:
                label_type, points_list = result

                # Random colour per ref; blue channel allowed full range.
                color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))

                # Same colour with low alpha for the translucent fill.
                color_a = color + (20, )
                for points in points_list:
                    x1, y1, x2, y2 = points

                    # Scale from the 0-999 normalized grid to pixel coords.
                    x1 = int(x1 / 999 * image_width)
                    y1 = int(y1 / 999 * image_height)

                    x2 = int(x2 / 999 * image_width)
                    y2 = int(y2 / 999 * image_height)

                    if label_type == 'image':
                        # Save a crop of the region; failure (e.g. missing
                        # output dir) is printed but still advances img_idx.
                        try:
                            cropped = image.crop((x1, y1, x2, y2))
                            cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
                        except Exception as e:
                            print(e)
                            pass
                        img_idx += 1

                    try:
                        # Titles get a thicker outline than other labels.
                        if label_type == 'title':
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
                        else:
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
                        # Label text just above the box (clamped to the top edge).
                        text_x = x1
                        text_y = max(0, y1 - 15)

                        # Background rectangle sized to the rendered label.
                        # NOTE(review): fill has 4 components but img_draw may be
                        # RGB, in which case the alpha is ignored — confirm intent.
                        text_bbox = draw.textbbox((0, 0), label_type, font=font)
                        text_width = text_bbox[2] - text_bbox[0]
                        text_height = text_bbox[3] - text_bbox[1]
                        draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                        fill=(255, 255, 255, 30))

                        draw.text((text_x, text_y), label_type, font=font, fill=color)
                    except:
                        # Best-effort drawing: skip a box that fails to render.
                        pass
        except:
            # Skip malformed refs entirely.
            continue
    # Composite the translucent fills over the outlined image.
    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw
|
|
|
|
def process_image_with_refs(image, ref_texts, output_path):
    """Render all reference bounding boxes onto a copy of *image*.

    Thin wrapper around draw_bounding_boxes; returns the annotated image.
    """
    return draw_bounding_boxes(image, ref_texts, output_path)
|
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the (cols, rows) grid from *target_ratios* closest to *aspect_ratio*.

    On an exact tie in aspect-ratio distance, the later candidate wins only
    when the original image area exceeds half the pixel area its grid covers
    (i.e. the image is big enough to justify more tiles).
    """
    best_diff = float('inf')
    best = (1, 1)
    area = width * height
    for candidate in target_ratios:
        diff = abs(aspect_ratio - candidate[0] / candidate[1])
        if diff < best_diff:
            best_diff = diff
            best = candidate
        elif diff == best_diff and area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            best = candidate

    return best
|
|
|
|
def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False):
    """Tile *image* into an aspect-ratio-matched grid of image_size squares.

    The image is resized to (cols*image_size, rows*image_size) for the best
    matching grid and cut into cols*rows tiles, left-to-right, top-to-bottom.

    Returns:
        (tiles, (cols, rows)). When use_thumbnail is True and more than one
        tile was produced, an image_size x image_size thumbnail is appended.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Candidate (cols, rows) grids whose tile count lies in [min_num, max_num],
    # ordered by ascending tile count.
    target_ratios = sorted(
        {(i, j)
         for n in range(min_num, max_num + 1)
         for i in range(1, n + 1)
         for j in range(1, n + 1)
         if min_num <= i * j <= max_num},
        key=lambda r: r[0] * r[1],
    )

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        orig_width,
        orig_height,
        image_size
    )

    cols, rows = target_aspect_ratio
    target_width = image_size * cols
    target_height = image_size * rows
    blocks = cols * rows

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size
    processed_images = []
    for idx in range(blocks):
        grid_col = idx % tiles_per_row
        grid_row = idx // tiles_per_row
        box = (
            grid_col * image_size,
            grid_row * image_size,
            (grid_col + 1) * image_size,
            (grid_row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images, target_aspect_ratio
|
|
|
|
def normalize_transform(mean, std):
    """Build a torchvision Normalize transform from optional mean/std.

    A missing mean defaults to zeros, a missing std to ones (sized to match
    the other argument). Returns None when both are None.
    """
    if mean is None and std is None:
        return None
    if mean is None:
        mean = [0.] * len(std)
    if std is None:
        std = [1.] * len(mean)
    return transforms.Normalize(mean=mean, std=std)
|
|
def format_messages(
    tokenizer: T,
    conversations: List[Dict[str, str]],
    system_prompt: str = "",
):
    """Apply the tokenizer's chat template, optionally prepending a system turn.

    A system message is inserted only when *system_prompt* is a non-empty
    string; otherwise the conversation is templated as-is.
    """
    if system_prompt is not None and system_prompt != "":
        conversations = [{
            "role": "system",
            "content": system_prompt,
        }] + conversations

    return tokenizer.apply_chat_template(
        conversations,
    )
|
|
|
|
def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    """
    Encode *text* to token ids, optionally wrapping with BOS/EOS.

    A BOS/EOS id is only added when the tokenizer actually defines one —
    e.g. the Qwen2VL tokenizer has bos_token_id=None, so no BOS is
    prepended and the chat template handles special tokens instead.
    """
    ids = tokenizer.encode(text, add_special_tokens=False)

    if bos and tokenizer.bos_token_id is not None:
        ids = [tokenizer.bos_token_id] + ids

    if eos and tokenizer.eos_token_id is not None:
        ids = ids + [tokenizer.eos_token_id]

    return ids
|
|
def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
    """Load every image referenced by the user turns of a conversation.

    Supports both list-of-parts and single-dict message contents; an image
    part may carry its path under 'image' or 'data'. Images that fail to
    load (load_image returns None) are skipped.

    Fix: previously, when one user message contained several image parts,
    only the last one was kept because the loaded image was appended after
    the inner loop; each image is now appended as it is loaded.
    """
    pil_images = []

    for message in conversations:
        if message["role"].lower() != "user":
            continue

        content = message["content"]
        if isinstance(content, List):
            for d in content:
                if d.get("type", "") == "image":
                    image_path = d.get("image") or d.get("data", "")
                    pil_image = load_image(image_path)
                    if pil_image is not None:
                        pil_images.append(pil_image)
        elif isinstance(content, Dict):
            if content.get("type", "") == "image":
                image_path = content.get("image") or content.get("data", "")
                pil_image = load_image(image_path)
                if pil_image is not None:
                    pil_images.append(pil_image)

    return pil_images
|
|
|
|
class BaseTransform(ABC):
    """Minimal interface for data transforms used by the collators."""

    def set_rng(self, *args, **kwargs):
        # No-op by default; subclasses may seed RNG state here.
        pass

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        # Subclasses implement the actual transformation.
        pass

    @property
    def default_shape(self):
        # Subclasses must declare the shape their transform produces.
        raise NotImplementedError
|
|
|
|
class BasicImageTransform(BaseTransform):
    """ToTensor conversion followed by optional channel normalization."""

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        self.mean = mean
        self.std = std

        pipeline = [transforms.ToTensor()]
        # When normalization is requested, normalize_transform fills in
        # defaults for a missing mean/std (and yields None only when both
        # are None); otherwise an identity step stands in.
        norm_step = normalize_transform(mean, std) if normalize else nn.Identity()
        if norm_step is not None:
            pipeline.append(norm_step)

        self.transform = transforms.Compose(pipeline)

    def __call__(self, x):
        return self.transform(x)
|
|
class NoEOSTextStreamer(TextStreamer):
    """TextStreamer that prints generated text with EOS rendered as a newline."""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Decode the EOS token once so its surface form can be replaced.
        eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
        print(text.replace(eos_text, "\n"), flush=True, end="")
|
|
|
|
|
|
|
|
| |
|
|
| import torch |
| import math |
| from dataclasses import dataclass |
| from typing import Dict, List, Any, Tuple |
| from PIL import Image, ImageOps |
| from torch.nn.utils.rnn import pad_sequence |
| import io |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
@dataclass
class DeepQwenDataCollator:
    """
    Data collator for DeepQwen model using Qwen2VL tokenizer.

    This collator processes images using DeepSeek OCR's dynamic cropping algorithm
    while maintaining compatibility with Qwen2VL's tokenization format.

    Key token mappings (Qwen2VL):
    - image_token: <|image_pad|> (id=151655)
    - vision_start: <|vision_start|> (id=151652)
    - vision_end: <|vision_end|> (id=151653)
    - eos_token: <|im_end|> (id=151645)
    - NO bos_token (bos_token_id is None)

    Args:
        tokenizer: Qwen2VL Tokenizer
        model: Model
        image_size: Size for image patches (default: 640)
        base_size: Size for global view (default: 1024)
        crop_mode: Whether to use dynamic cropping for large images
        train_on_responses_only: If True, only train on assistant responses (mask user prompts)
    """
    # NOTE(review): the explicit __init__ below replaces the dataclass-generated
    # one, so these field declarations act only as class-level annotations.
    tokenizer: T
    model: Any
    image_size: int = 640
    base_size: int = 1024
    crop_mode: bool = True
    train_on_responses_only: bool = True

    def __init__(
        self,
        tokenizer,
        model,
        image_size: int = 640,
        base_size: int = 1024,
        crop_mode: bool = True,
        train_on_responses_only: bool = True,
        max_length: Optional[int] = None,
    ):
        self.tokenizer = tokenizer
        self.model = model
        self.image_size = image_size
        self.base_size = base_size
        self.crop_mode = crop_mode
        # Image tensors produced by this collator are cast to the model's dtype.
        self.dtype = model.dtype
        self.train_on_responses_only = train_on_responses_only
        # Optional hard cap on sequence length applied at batch time.
        self.max_length = max_length

        # Resolve the image placeholder token id, falling back to Qwen2VL's
        # <|image_pad|> id when the tokenizer does not expose one.
        self.image_token_id = getattr(tokenizer, 'image_token_id', None)
        if self.image_token_id is None:
            self.image_token_id = 151655

        # Surface form of the image token (kept for debugging/templating).
        self.image_token = tokenizer.decode([self.image_token_id], skip_special_tokens=False)

        # Vision-region delimiter ids (Qwen2VL defaults).
        self.vision_start_token_id = getattr(tokenizer, 'vision_start_token_id', 151652)
        self.vision_end_token_id = getattr(tokenizer, 'vision_end_token_id', 151653)

        # Normalization matches DeepSeek OCR preprocessing (mean/std 0.5).
        self.image_transform = BasicImageTransform(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            normalize=True
        )
        # Vision encoder geometry: 16px patches, 4x token downsampling.
        self.patch_size = 16
        self.downsample_ratio = 4

        # Cached special-token ids; bos_id may be None for Qwen2VL.
        self.bos_id = tokenizer.bos_token_id
        self.eos_id = tokenizer.eos_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def deserialize_image(self, image_data) -> Image.Image:
        """Convert image data (bytes dict, PIL Image, or file path) to PIL Image in RGB mode"""
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")
        elif isinstance(image_data, str):
            # A string is treated as a file path; load_image applies EXIF rotation.
            image = load_image(image_data)
            if image is None:
                raise ValueError(f"Failed to load image from path: {image_data}")
            return image.convert("RGB")
        elif isinstance(image_data, dict) and 'bytes' in image_data:
            # HuggingFace-datasets-style payload: {'bytes': <raw image bytes>}.
            image_bytes = image_data['bytes']
            image = Image.open(io.BytesIO(image_bytes))
            return image.convert("RGB")
        else:
            raise ValueError(f"Unsupported image format: {type(image_data)}")

    def calculate_image_token_count(self, image: Image.Image, crop_ratio: Tuple[int, int]) -> int:
        """Calculate the number of tokens this image will generate"""
        # Tokens per side = (pixels // patch_size) / downsample_ratio.
        num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
        num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)

        width_crop_num, height_crop_num = crop_ratio

        if self.crop_mode:
            # NOTE(review): process_image emits (num_queries_base + 1) * num_queries_base + 1
            # tokens for the global view (one extra token per row); the count
            # here omits those per-row tokens — confirm which is authoritative.
            img_tokens = num_queries_base * num_queries_base + 1
            if width_crop_num > 1 or height_crop_num > 1:
                img_tokens += (num_queries * width_crop_num + 1) * (num_queries * height_crop_num)
        else:
            img_tokens = num_queries * num_queries + 1

        return img_tokens

    def process_image(self, image: Image.Image) -> Tuple[List, List, List, List, Tuple[int, int]]:
        """
        Process a single image based on crop_mode and size thresholds

        Returns:
            Tuple of (images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio)
        """
        images_list = []
        images_crop_list = []
        images_spatial_crop = []

        if self.crop_mode:
            # Small images get no tiling, only the padded global view.
            # NOTE(review): threshold is hard-coded to 640 rather than
            # self.image_size — confirm this is intentional.
            if image.size[0] <= 640 and image.size[1] <= 640:
                crop_ratio = (1, 1)
                images_crop_raw = []
            else:
                images_crop_raw, crop_ratio = dynamic_preprocess(
                    image, min_num=2, max_num=9,
                    image_size=self.image_size, use_thumbnail=False
                )

            # Global view: pad to base_size x base_size with the mean colour.
            global_view = ImageOps.pad(
                image, (self.base_size, self.base_size),
                color=tuple(int(x * 255) for x in self.image_transform.mean)
            )
            images_list.append(self.image_transform(global_view).to(self.dtype))

            width_crop_num, height_crop_num = crop_ratio
            images_spatial_crop.append([width_crop_num, height_crop_num])

            # Local tiles are only kept when the image was actually tiled.
            if width_crop_num > 1 or height_crop_num > 1:
                for crop_img in images_crop_raw:
                    images_crop_list.append(
                        self.image_transform(crop_img).to(self.dtype)
                    )

            # Placeholder token counts per side for tile and global views.
            num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
            num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)

            # Global view placeholders: one row of num_queries_base tokens plus
            # one extra token per row, then a single trailing token.
            tokenized_image = ([self.image_token_id] * num_queries_base + [self.image_token_id]) * num_queries_base
            tokenized_image += [self.image_token_id]

            # Tile placeholders follow the same row-plus-separator layout.
            if width_crop_num > 1 or height_crop_num > 1:
                tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
                    num_queries * height_crop_num)

        else:
            # No-crop mode: a single base_size x base_size view.
            crop_ratio = (1, 1)
            images_spatial_crop.append([1, 1])

            # Small targets are resized directly; larger ones are padded to
            # preserve aspect ratio.
            if self.base_size <= 640:
                resized_image = image.resize((self.base_size, self.base_size), Image.LANCZOS)
                images_list.append(self.image_transform(resized_image).to(self.dtype))
            else:
                global_view = ImageOps.pad(
                    image, (self.base_size, self.base_size),
                    color=tuple(int(x * 255) for x in self.image_transform.mean)
                )
                images_list.append(self.image_transform(global_view).to(self.dtype))

            num_queries = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)
            tokenized_image = ([self.image_token_id] * num_queries + [self.image_token_id]) * num_queries
            tokenized_image += [self.image_token_id]

        return images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio

    def process_single_sample(self, messages: List[Dict]) -> Dict[str, Any]:
        """
        Process a single conversation into model inputs.

        Expected message format (Qwen2.5-VL native style):
        [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": <PIL.Image or path or bytes>},
                    {"type": "text", "text": "Describe this image."}
                ]
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "This is a description..."}]
            }
        ]

        Also supports string content for backward compatibility.
        """
        tokenized_str = []          # flat token ids for the whole conversation
        images_seq_mask = []        # True where a position is an image placeholder
        images_list, images_crop_list, images_spatial_crop = [], [], []

        # Number of tokens before the first assistant turn; -1 until found.
        prompt_token_count = -1
        assistant_started = False

        for message in messages:
            role = message["role"].lower()
            content = message["content"]

            # Everything before the first assistant message counts as prompt
            # (used later to mask the prompt out of the loss).
            if role == "assistant":
                if not assistant_started:
                    prompt_token_count = len(tokenized_str)
                    assistant_started = True

            if isinstance(content, list):
                # Structured content: a list of typed parts.
                content_parts = []  # (unused)

                for item in content:
                    item_type = item.get("type", "")

                    if item_type == "image":
                        # Image part: preprocess and splice in placeholder tokens.
                        image_data = item.get("image") or item.get("data")
                        if image_data is not None:
                            pil_image = self.deserialize_image(image_data)

                            img_list, crop_list, spatial_crop, tok_img, _ = self.process_image(pil_image)

                            images_list.extend(img_list)
                            images_crop_list.extend(crop_list)
                            images_spatial_crop.extend(spatial_crop)

                            tokenized_str.extend(tok_img)
                            images_seq_mask.extend([True] * len(tok_img))

                    elif item_type == "text":
                        text = item.get("text", "")

                        # Append EOS to the final part of an assistant turn.
                        # NOTE(review): `item == content[-1]` compares dicts by
                        # value, so a duplicate of the last part would also
                        # match — confirm this edge case is acceptable.
                        if role == "assistant" and item == content[-1]:
                            if self.tokenizer.eos_token:
                                text = f"{text.strip()}{self.tokenizer.eos_token}"

                        tokenized_text = text_encode(self.tokenizer, text, bos=False, eos=False)
                        tokenized_str.extend(tokenized_text)
                        images_seq_mask.extend([False] * len(tokenized_text))

            else:
                # Plain-string content (backward compatibility).
                text_content = content

                if role == "assistant" and self.tokenizer.eos_token:
                    text_content = f"{text_content.strip()}{self.tokenizer.eos_token}"

                tokenized_text = text_encode(self.tokenizer, text_content, bos=False, eos=False)
                tokenized_str.extend(tokenized_text)
                images_seq_mask.extend([False] * len(tokenized_text))

        # Without an assistant turn every token is prompt, so all get masked.
        if not assistant_started:
            print("Warning: No assistant message found in sample. Masking all tokens.")
            prompt_token_count = len(tokenized_str)

        # Stack per-image tensors; raises if the sample contained no image.
        images_ori = torch.stack(images_list, dim=0)
        images_spatial_crop_tensor = torch.tensor(images_spatial_crop, dtype=torch.long)

        if images_crop_list:
            images_crop = torch.stack(images_crop_list, dim=0)
        else:
            # Dummy crop tensor so downstream code always gets a 4D tensor.
            images_crop = torch.zeros((1, 3, self.base_size, self.base_size), dtype=self.dtype)

        return {
            "input_ids": torch.tensor(tokenized_str, dtype=torch.long),
            "images_seq_mask": torch.tensor(images_seq_mask, dtype=torch.bool),
            "images_ori": images_ori,
            "images_crop": images_crop,
            "images_spatial_crop": images_spatial_crop_tensor,
            "prompt_token_count": prompt_token_count,
        }

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate batch of samples.

        Expected feature format:
        {
            "prompt": str,    # The user's question/instruction
            "response": str,  # The assistant's response
            "image": PIL.Image or bytes dict  # The image
        }

        This will be converted to Qwen2.5-VL native conversation format:
        [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": <PIL.Image>},
                    {"type": "text", "text": "<prompt>"}
                ]
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "<response>"}]
            }
        ]
        """
        batch_data = []

        for feature in features:
            try:
                # Accept the image under either 'image' or 'image_path'.
                image_data = feature.get('image') or feature.get('image_path')
                if image_data is None:
                    raise ValueError("Sample missing both 'image' and 'image_path' keys")

                # Wrap the flat prompt/response pair in the conversation
                # structure expected by process_single_sample.
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": image_data},
                            {"type": "text", "text": feature['prompt']}
                        ]
                    },
                    {
                        "role": "assistant",
                        "content": [
                            {"type": "text", "text": feature["response"]}
                        ]
                    }
                ]

                processed = self.process_single_sample(messages)
                batch_data.append(processed)
            except Exception as e:
                # Skip malformed samples rather than failing the whole batch.
                print(f"Error processing sample: {e}")
                continue

        if not batch_data:
            raise ValueError("No valid samples in batch")

        input_ids_list = [item['input_ids'] for item in batch_data]
        images_seq_mask_list = [item['images_seq_mask'] for item in batch_data]
        prompt_token_counts = [item['prompt_token_count'] for item in batch_data]

        # Right-pad to the longest sequence in the batch.
        input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=self.pad_token_id)
        images_seq_mask = pad_sequence(images_seq_mask_list, batch_first=True, padding_value=False)

        # Optional truncation to max_length (prompt counts clamped to match).
        if self.max_length is not None and input_ids.shape[1] > self.max_length:
            input_ids = input_ids[:, :self.max_length]
            images_seq_mask = images_seq_mask[:, :self.max_length]

            prompt_token_counts = [min(p, self.max_length) for p in prompt_token_counts]

        labels = input_ids.clone()

        # Never compute loss on padding...
        labels[labels == self.pad_token_id] = -100

        # ...or on image placeholder positions.
        labels[images_seq_mask] = -100

        # Optionally mask the prompt so only assistant tokens are supervised.
        if self.train_on_responses_only:
            for idx, prompt_count in enumerate(prompt_token_counts):
                if prompt_count > 0:
                    labels[idx, :prompt_count] = -100

        # Attention mask: 1 for real tokens, 0 for padding.
        attention_mask = (input_ids != self.pad_token_id).long()

        # Per-sample (crop tensor, global-view tensor) pairs for the model.
        images_batch = []
        for item in batch_data:
            images_batch.append((item['images_crop'], item['images_ori']))

        images_spatial_crop = torch.cat([item['images_spatial_crop'] for item in batch_data], dim=0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "images": images_batch,
            "images_seq_mask": images_seq_mask,
            "images_spatial_crop": images_spatial_crop,
        }
|
|
|
|