| import json |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Sequence, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| from transformers.modeling_outputs import BaseModelOutputWithPooling |
| from transformers.processing_utils import BatchFeature |
|
|
| from .configuration_paddleocr_vl import PaddleOCRVLConfig |
| from .image_processing_paddleocr_vl import PaddleOCRVLImageProcessor |
| from .modeling_paddleocr_vl import PaddleOCRVisionModel, Projector |
|
|
|
|
# File names for the standalone artifacts written by
# `extract_and_save_vision_encoder_artifacts` and looked up by the
# `from_pretrained` loaders below.
VISION_TOWER_CONFIG_NAME = "vision_tower_config.json"
VISION_TOWER_WEIGHTS_NAME = "vision_tower.safetensors"
PROJECTOR_CONFIG_NAME = "projector_config.json"
PROJECTOR_WEIGHTS_NAME = "projector.safetensors"
# File names found in a full PaddleOCR-VL checkpoint directory (fallback source).
FULL_MODEL_CONFIG_NAME = "config.json"
FULL_MODEL_WEIGHTS_NAME = "model.safetensors"
# State-dict key prefixes: full model vs. the standalone modules. The visual
# prefix is kept as-is; projector keys are renamed "mlp_AR.*" -> "projector.*".
FULL_VISUAL_PREFIX = "visual."
FULL_PROJECTOR_PREFIX = "mlp_AR."
STANDALONE_VISUAL_PREFIX = "visual."
STANDALONE_PROJECTOR_PREFIX = "projector."
# Temporal patch size used when building a default image processor.
# NOTE(review): 1 presumably means single-frame (still-image) input — confirm
# against PaddleOCRVLImageProcessor.
IMAGE_PROCESSOR_TEMPORAL_PATCH_SIZE = 1
|
|
|
|
| def _read_json(path: Union[str, Path]) -> Dict[str, Any]: |
| with open(path, "r", encoding="utf-8") as f: |
| return json.load(f) |
|
|
|
|
| def _write_json(path: Union[str, Path], payload: Dict[str, Any]) -> None: |
| with open(path, "w", encoding="utf-8") as f: |
| json.dump(payload, f, indent=2, ensure_ascii=False) |
|
|
|
|
| def _normalize_image_grid_thw( |
| image_grid_thw: Union[torch.Tensor, Sequence[Any]] |
| ) -> List[Tuple[int, int, int]]: |
| if isinstance(image_grid_thw, torch.Tensor): |
| return [tuple(int(v) for v in row.tolist()) for row in image_grid_thw] |
|
|
| normalized: List[Tuple[int, int, int]] = [] |
| for item in image_grid_thw: |
| if isinstance(item, torch.Tensor): |
| normalized.append(tuple(int(v) for v in item.tolist())) |
| else: |
| normalized.append(tuple(int(v) for v in item)) |
| return normalized |
|
|
|
|
def build_vision_encoder_export_config(
    full_config: Union[PaddleOCRVLConfig, Dict[str, Any]]
) -> Dict[str, Any]:
    """Derive the export config for the combined tower + projector module.

    Accepts either a ``PaddleOCRVLConfig`` instance or its dict form. The
    entire source config is embedded under ``"full_model_config"`` so the
    standalone checkpoint can be traced back to (and rebuilt from) the full
    model.
    """
    full_config_dict = (
        full_config.to_dict()
        if isinstance(full_config, PaddleOCRVLConfig)
        else dict(full_config)
    )
    vision_config = dict(full_config_dict["vision_config"])
    text_hidden_size = full_config_dict["hidden_size"]

    return {
        "model_type": "paddleocr_vl_vision_encoder",
        "architectures": ["PaddleOCRVLVisionEncoder"],
        "source_model_type": full_config_dict.get("model_type", "paddleocr_vl"),
        "source_architecture": "PaddleOCRVLForConditionalGeneration",
        "text_hidden_size": text_hidden_size,
        "image_token_id": full_config_dict.get("image_token_id"),
        "vision_start_token_id": full_config_dict.get("vision_start_token_id"),
        "vision_end_token_id": full_config_dict.get("vision_end_token_id"),
        "torch_dtype": full_config_dict.get("torch_dtype"),
        "vision_config": vision_config,
        "projector": {
            # NOTE(review): hard-coded 2x2 merge kernel — presumably fixed for
            # this architecture; confirm against the Projector implementation.
            "merge_kernel_size": [2, 2],
            "input_hidden_size": vision_config["hidden_size"],
            "output_hidden_size": text_hidden_size,
        },
        "required_weight_prefixes": [
            STANDALONE_VISUAL_PREFIX,
            STANDALONE_PROJECTOR_PREFIX,
        ],
        "source_weight_prefixes": {
            "visual": FULL_VISUAL_PREFIX,
            "projector": FULL_PROJECTOR_PREFIX,
        },
        "full_model_config": full_config_dict,
    }
|
|
|
|
def build_vision_tower_export_config(
    full_config: Union[PaddleOCRVLConfig, Dict[str, Any]]
) -> Dict[str, Any]:
    """Derive the vision-tower-only export config (no projector section)."""
    combined = build_vision_encoder_export_config(full_config)
    tower_config: Dict[str, Any] = {
        "model_type": "paddleocr_vl_vision_tower",
        "architectures": ["PaddleOCRVLVisionTower"],
    }
    tower_config["torch_dtype"] = combined.get("torch_dtype")
    tower_config["vision_config"] = combined["vision_config"]
    tower_config["required_weight_prefixes"] = [STANDALONE_VISUAL_PREFIX]
    tower_config["source_weight_prefixes"] = {"visual": FULL_VISUAL_PREFIX}
    tower_config["full_model_config"] = combined["full_model_config"]
    return tower_config
|
|
|
|
def build_projector_export_config(
    full_config: Union[PaddleOCRVLConfig, Dict[str, Any]]
) -> Dict[str, Any]:
    """Derive the projector-only export config (no tower weights required)."""
    combined = build_vision_encoder_export_config(full_config)
    projector_config: Dict[str, Any] = {
        "model_type": "paddleocr_vl_projector",
        "architectures": ["PaddleOCRVLProjector"],
    }
    # Carry over the shared sections from the combined config verbatim.
    for key in ("torch_dtype", "vision_config", "text_hidden_size", "projector"):
        projector_config[key] = combined[key]
    projector_config["required_weight_prefixes"] = [STANDALONE_PROJECTOR_PREFIX]
    projector_config["source_weight_prefixes"] = {"projector": FULL_PROJECTOR_PREFIX}
    projector_config["full_model_config"] = combined["full_model_config"]
    return projector_config
|
|
|
|
def remap_full_model_state_dict_to_vision_encoder_parts(
    full_state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, List[str]]]:
    """Split a full-model state dict into standalone tower and projector dicts.

    Returns ``(visual_state_dict, projector_state_dict, consumed)`` where the
    dict keys are re-prefixed for the standalone modules and ``consumed`` maps
    ``"visual"``/``"projector"`` to the sorted full-model keys that were used.

    Raises:
        ValueError: if either component has no matching keys.
    """

    def _extract(
        full_prefix: str, standalone_prefix: str
    ) -> Tuple[Dict[str, torch.Tensor], List[str]]:
        # Re-prefix every matching key; tensors are shared, not copied.
        remapped: Dict[str, torch.Tensor] = {}
        consumed_keys: List[str] = []
        for key, tensor in full_state_dict.items():
            if key.startswith(full_prefix):
                remapped[standalone_prefix + key[len(full_prefix) :]] = tensor
                consumed_keys.append(key)
        return remapped, consumed_keys

    visual_state_dict, consumed_visual = _extract(
        FULL_VISUAL_PREFIX, STANDALONE_VISUAL_PREFIX
    )
    projector_state_dict, consumed_projector = _extract(
        FULL_PROJECTOR_PREFIX, STANDALONE_PROJECTOR_PREFIX
    )

    if not consumed_visual:
        raise ValueError("No visual.* weights were found in the full model state dict.")
    if not consumed_projector:
        raise ValueError("No mlp_AR.* weights were found in the full model state dict.")

    consumed = {
        "visual": sorted(consumed_visual),
        "projector": sorted(consumed_projector),
    }
    return visual_state_dict, projector_state_dict, consumed
|
|
|
|
def remap_full_model_state_dict_to_vision_encoder(
    full_state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, List[str]]]:
    """Remap a full-model state dict into one combined vision-encoder dict.

    Same as the *_parts* variant but with the tower and projector dicts merged
    (their key prefixes are disjoint, so no collisions occur).
    """
    visual_part, projector_part, consumed = (
        remap_full_model_state_dict_to_vision_encoder_parts(full_state_dict)
    )
    return {**visual_part, **projector_part}, consumed
|
|
|
|
def _load_safetensors_state_dict(path: Union[str, Path]) -> Dict[str, torch.Tensor]:
    """Load a safetensors checkpoint as a ``{name: tensor}`` dict.

    Raises:
        RuntimeError: if the optional ``safetensors`` dependency is missing.
    """
    try:
        from safetensors.torch import load_file
    except ImportError as import_error:
        message = "Loading safetensors requires the `safetensors` package to be installed."
        raise RuntimeError(message) from import_error
    return load_file(str(path))
|
|
|
|
def _save_safetensors_state_dict(
    state_dict: Dict[str, torch.Tensor], path: Union[str, Path]
) -> None:
    """Write ``state_dict`` to *path* in safetensors format.

    Raises:
        RuntimeError: if the optional ``safetensors`` dependency is missing.
    """
    try:
        from safetensors.torch import save_file
    except ImportError as import_error:
        message = "Saving safetensors requires the `safetensors` package to be installed."
        raise RuntimeError(message) from import_error
    save_file(state_dict, str(path))
|
|
|
|
def extract_and_save_vision_encoder_artifacts(
    full_config: Union[PaddleOCRVLConfig, Dict[str, Any]],
    full_state_dict: Dict[str, torch.Tensor],
    output_dir: Union[str, Path],
) -> Dict[str, Any]:
    """Export the standalone vision artifacts of a full PaddleOCR-VL checkpoint.

    Writes three weights/config pairs under *output_dir*: the vision tower,
    the projector, and (inside a ``combined/`` subdirectory) the merged
    tower+projector checkpoint.

    Returns:
        Metadata dict with every written path, exported tensor counts, and the
        consumed full-model keys.

    Raises:
        ValueError: if the state dict lacks visual.* or mlp_AR.* weights.
        RuntimeError: if the ``safetensors`` package is not installed.
    """
    out_root = Path(output_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    def _dump(
        state_dict: Dict[str, torch.Tensor],
        weights_path: Path,
        config_payload: Dict[str, Any],
        config_path: Path,
    ) -> None:
        # Each artifact is a safetensors file plus a JSON config side-car.
        _save_safetensors_state_dict(state_dict, weights_path)
        _write_json(config_path, config_payload)

    tower_config = build_vision_tower_export_config(full_config)
    projector_config = build_projector_export_config(full_config)
    visual_state_dict, projector_state_dict, consumed = (
        remap_full_model_state_dict_to_vision_encoder_parts(full_state_dict)
    )
    _dump(
        visual_state_dict,
        out_root / VISION_TOWER_WEIGHTS_NAME,
        tower_config,
        out_root / VISION_TOWER_CONFIG_NAME,
    )
    _dump(
        projector_state_dict,
        out_root / PROJECTOR_WEIGHTS_NAME,
        projector_config,
        out_root / PROJECTOR_CONFIG_NAME,
    )

    combined_dir = out_root / "combined"
    combined_dir.mkdir(parents=True, exist_ok=True)
    combined_state_dict, _ = remap_full_model_state_dict_to_vision_encoder(
        full_state_dict
    )
    _dump(
        combined_state_dict,
        combined_dir / "vision_encoder.safetensors",
        build_vision_encoder_export_config(full_config),
        combined_dir / "vision_encoder_config.json",
    )

    return {
        "vision_tower_config_path": str(out_root / VISION_TOWER_CONFIG_NAME),
        "vision_tower_weights_path": str(out_root / VISION_TOWER_WEIGHTS_NAME),
        "projector_config_path": str(out_root / PROJECTOR_CONFIG_NAME),
        "projector_weights_path": str(out_root / PROJECTOR_WEIGHTS_NAME),
        "combined_config_path": str(combined_dir / "vision_encoder_config.json"),
        "combined_weights_path": str(combined_dir / "vision_encoder.safetensors"),
        "num_exported_visual_tensors": len(visual_state_dict),
        "num_exported_projector_tensors": len(projector_state_dict),
        "consumed_full_model_keys": consumed,
    }
|
|
|
|
class PaddleOCRVLVisionTower(torch.nn.Module):
    """Standalone wrapper around the PaddleOCR-VL vision backbone (``visual.*``).

    Runs only the vision model over flattened image patches; the projection
    into the language-model hidden space lives in ``PaddleOCRVLProjector``.
    """

    def __init__(self, config: PaddleOCRVLConfig):
        super().__init__()
        self.config = config
        # Vision backbone only; no language model, no projector.
        self.visual = PaddleOCRVisionModel(config.vision_config)
        # Serializable payload written next to the weights by save_pretrained().
        self.export_config = build_vision_tower_export_config(config)

    @staticmethod
    def _resolve_full_config(config_payload: Dict[str, Any]) -> PaddleOCRVLConfig:
        """Build a full config from a standalone-tower or full-model payload."""
        if config_payload.get("model_type") == "paddleocr_vl_vision_tower":
            # Standalone export configs embed the original full-model config.
            config_payload = config_payload["full_model_config"]
        return PaddleOCRVLConfig(**config_payload)

    @classmethod
    def from_pretrained(cls, model_dir: Union[str, Path]) -> "PaddleOCRVLVisionTower":
        """Load the tower from a standalone export directory, falling back to a
        full checkpoint directory (config.json + model.safetensors)."""
        model_dir = Path(model_dir)
        config_path = model_dir / VISION_TOWER_CONFIG_NAME
        weights_path = model_dir / VISION_TOWER_WEIGHTS_NAME
        if config_path.exists():
            config_payload = _read_json(config_path)
        else:
            # No standalone config: read the full-model config.json instead.
            config_payload = _read_json(model_dir / FULL_MODEL_CONFIG_NAME)
        model = cls(cls._resolve_full_config(config_payload))
        if weights_path.exists():
            state_dict = _load_safetensors_state_dict(weights_path)
        else:
            # Extract the visual.* subset from the full model.safetensors.
            full_state_dict = _load_safetensors_state_dict(model_dir / FULL_MODEL_WEIGHTS_NAME)
            state_dict, _, _ = remap_full_model_state_dict_to_vision_encoder_parts(
                full_state_dict
            )
        # With strict=True load_state_dict already raises on any mismatch, so
        # the explicit check below is defensive belt-and-braces.
        missing, unexpected = model.load_state_dict(state_dict, strict=True)
        if missing or unexpected:
            raise RuntimeError(
                f"Failed to load standalone vision tower weights. Missing: {missing}, unexpected: {unexpected}"
            )
        return model

    def save_pretrained(self, output_dir: Union[str, Path]) -> None:
        """Write the tower weights and export config into *output_dir*."""
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _save_safetensors_state_dict(self.state_dict(), output_dir / VISION_TOWER_WEIGHTS_NAME)
        _write_json(output_dir / VISION_TOWER_CONFIG_NAME, self.export_config)

    @staticmethod
    def _build_visual_inputs(
        pixel_values: torch.Tensor,
        image_grid_thw: List[Tuple[int, int, int]],
        device: torch.device,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[int, int, int]],
        torch.Tensor,
        torch.Tensor,
    ]:
        """Build the flattened-patch auxiliary inputs for the vision model.

        Returns ``(pixel_values_5d, position_ids, grid_thw, sample_indices,
        cu_seqlens)``: position ids restart every frame (index modulo h*w),
        sample_indices maps each patch to its image index, and cu_seqlens holds
        cumulative per-image patch counts as int32 (varlen-attention style).

        Raises:
            ValueError: if pixel_values is neither 4-D nor 5-D.
        """
        if pixel_values.dim() == 4:
            # Promote [num_patches, C, H, W] to a singleton batch dimension.
            pixel_values = pixel_values.unsqueeze(0)
        elif pixel_values.dim() != 5:
            raise ValueError(
                "pixel_values must have shape [num_patches, C, H, W] or [1, num_patches, C, H, W]."
            )

        siglip_position_ids = []
        sample_indices = []
        # Cumulative sequence-length boundaries, starting at 0.
        cu_seqlens = [0]

        for idx, thw in enumerate(image_grid_thw):
            numel = int(np.prod(thw))
            # Spatial position within each frame: wraps every h*w patches.
            image_position_ids = torch.arange(numel, device=device) % int(np.prod(thw[1:]))
            siglip_position_ids.append(image_position_ids)
            sample_indices.append(torch.full((numel,), idx, dtype=torch.int64, device=device))
            cu_seqlens.append(cu_seqlens[-1] + numel)

        if siglip_position_ids:
            siglip_position_ids = torch.cat(siglip_position_ids, dim=0)
            sample_indices = torch.cat(sample_indices, dim=0)
        else:
            # No images: keep the downstream contract with empty long tensors.
            siglip_position_ids = torch.empty(0, dtype=torch.long, device=device)
            sample_indices = torch.empty(0, dtype=torch.long, device=device)

        cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
        return pixel_values, siglip_position_ids, image_grid_thw, sample_indices, cu_seqlens_tensor

    def forward(
        self,
        pixel_values: torch.Tensor,
        image_grid_thw: Union[torch.Tensor, Sequence[Any]],
    ) -> Dict[str, Any]:
        """Encode preprocessed image patches.

        Returns a dict with the backbone's last hidden state under
        ``"visual_embeds"`` plus the auxiliary index tensors used to compute it.
        """
        image_grid_thw_list = _normalize_image_grid_thw(image_grid_thw)
        # Cast inputs to the tower's parameter dtype (e.g. after .half()).
        vision_dtype = next(self.visual.parameters()).dtype
        pixel_values = pixel_values.to(dtype=vision_dtype)
        device = pixel_values.device

        (
            pixel_values_5d,
            siglip_position_ids,
            image_grid_hws,
            sample_indices,
            cu_seqlens,
        ) = self._build_visual_inputs(pixel_values, image_grid_thw_list, device)

        # NOTE(review): the exact semantics of use_rope, window_size=-1 and the
        # other flags come from PaddleOCRVisionModel's signature — confirm
        # against modeling_paddleocr_vl before changing any of them.
        vision_outputs: BaseModelOutputWithPooling = self.visual(
            pixel_values=pixel_values_5d,
            image_grid_thw=image_grid_hws,
            position_ids=siglip_position_ids,
            vision_return_embed_list=True,
            interpolate_pos_encoding=True,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            return_pooler_output=False,
            use_rope=True,
            window_size=-1,
        )
        return {
            "visual_embeds": vision_outputs.last_hidden_state,
            "image_grid_thw": image_grid_thw_list,
            "siglip_position_ids": siglip_position_ids,
            "sample_indices": sample_indices,
            "cu_seqlens": cu_seqlens,
        }

    def encode_images(
        self,
        images: Any,
        image_processor: Optional[PaddleOCRVLImageProcessor] = None,
        **processor_kwargs: Any,
    ) -> Dict[str, Any]:
        """Preprocess raw images and run the tower on the resulting patches.

        Builds a default processor from the vision config when none is given;
        extra keyword arguments are forwarded to the processor call.
        """
        image_processor = image_processor or PaddleOCRVLImageProcessor(
            patch_size=self.config.vision_config.patch_size,
            temporal_patch_size=IMAGE_PROCESSOR_TEMPORAL_PATCH_SIZE,
            merge_size=self.config.vision_config.spatial_merge_size,
        )
        encoded: BatchFeature = image_processor(
            images=images, return_tensors="pt", **processor_kwargs
        )
        return self.forward(
            pixel_values=encoded["pixel_values"], image_grid_thw=encoded["image_grid_thw"]
        )
|
|
|
|
class PaddleOCRVLProjector(torch.nn.Module):
    """Standalone wrapper around the PaddleOCR-VL vision-to-text projector.

    Holds only the ``projector.*`` weights and maps vision-tower embeddings
    into the language model's hidden space.
    """

    def __init__(self, config: PaddleOCRVLConfig):
        super().__init__()
        self.config = config
        self.projector = Projector(config, config.vision_config)
        # Serializable payload written next to the weights by save_pretrained().
        self.export_config = build_projector_export_config(config)

    @staticmethod
    def _resolve_full_config(config_payload: Dict[str, Any]) -> PaddleOCRVLConfig:
        """Build a full config from a standalone-projector or full-model payload."""
        payload = (
            config_payload["full_model_config"]
            if config_payload.get("model_type") == "paddleocr_vl_projector"
            else config_payload
        )
        return PaddleOCRVLConfig(**payload)

    @classmethod
    def from_pretrained(cls, model_dir: Union[str, Path]) -> "PaddleOCRVLProjector":
        """Load the projector from a standalone export directory, falling back
        to a full checkpoint directory (config.json + model.safetensors)."""
        root = Path(model_dir)
        standalone_config = root / PROJECTOR_CONFIG_NAME
        standalone_weights = root / PROJECTOR_WEIGHTS_NAME

        if standalone_config.exists():
            payload = _read_json(standalone_config)
        else:
            payload = _read_json(root / FULL_MODEL_CONFIG_NAME)
        model = cls(cls._resolve_full_config(payload))

        if standalone_weights.exists():
            state_dict = _load_safetensors_state_dict(standalone_weights)
        else:
            # Extract the projector subset from the full model weights.
            full_state_dict = _load_safetensors_state_dict(root / FULL_MODEL_WEIGHTS_NAME)
            _, state_dict, _ = remap_full_model_state_dict_to_vision_encoder_parts(
                full_state_dict
            )

        # strict=True already raises on mismatch; the check below is defensive.
        missing, unexpected = model.load_state_dict(state_dict, strict=True)
        if missing or unexpected:
            raise RuntimeError(
                f"Failed to load standalone projector weights. Missing: {missing}, unexpected: {unexpected}"
            )
        return model

    def save_pretrained(self, output_dir: Union[str, Path]) -> None:
        """Write the projector weights and export config into *output_dir*."""
        target = Path(output_dir)
        target.mkdir(parents=True, exist_ok=True)
        _save_safetensors_state_dict(self.state_dict(), target / PROJECTOR_WEIGHTS_NAME)
        _write_json(target / PROJECTOR_CONFIG_NAME, self.export_config)

    def forward(
        self,
        visual_embeds: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
        image_grid_thw: Union[torch.Tensor, Sequence[Any]],
    ) -> Dict[str, Any]:
        """Project vision embeddings into the language-model hidden space.

        Returns the per-image embeddings, their concatenation along dim 0 (an
        empty ``[0, hidden_size]`` tensor when there are no images), and the
        normalized grid descriptors.
        """
        grid_thw = _normalize_image_grid_thw(image_grid_thw)
        image_embeds = self.projector(visual_embeds, grid_thw)
        if image_embeds:
            concat_image_embeds = torch.cat(image_embeds, dim=0)
        else:
            # Borrow dtype/device from the projector's own parameters.
            reference = next(self.projector.parameters())
            concat_image_embeds = torch.empty(
                0,
                self.config.hidden_size,
                device=reference.device,
                dtype=reference.dtype,
            )
        return {
            "image_embeds": image_embeds,
            "concat_image_embeds": concat_image_embeds,
            "image_grid_thw": grid_thw,
        }
|
|
| class PaddleOCRVLVisionEncoder(torch.nn.Module): |
| def __init__(self, config: PaddleOCRVLConfig): |
| super().__init__() |
| self.config = config |
| self.vision_tower = PaddleOCRVLVisionTower(config) |
| self.projector = PaddleOCRVLProjector(config) |
| self.export_config = build_vision_encoder_export_config(config) |
|
|
| @classmethod |
| def from_pretrained(cls, model_dir: Union[str, Path]) -> "PaddleOCRVLVisionEncoder": |
| model_dir = Path(model_dir) |
| config_candidates = [ |
| model_dir / FULL_MODEL_CONFIG_NAME, |
| model_dir / VISION_TOWER_CONFIG_NAME, |
| model_dir / PROJECTOR_CONFIG_NAME, |
| ] |
| config_path = next((path for path in config_candidates if path.exists()), None) |
| if config_path is None: |
| raise FileNotFoundError( |
| "Could not find config.json, vision_tower_config.json, or projector_config.json." |
| ) |
| config_payload = _read_json(config_path) |
| if config_payload.get("model_type") == "paddleocr_vl_vision_tower": |
| config = PaddleOCRVLVisionTower._resolve_full_config(config_payload) |
| elif config_payload.get("model_type") == "paddleocr_vl_projector": |
| config = PaddleOCRVLProjector._resolve_full_config(config_payload) |
| else: |
| config = PaddleOCRVLProjector._resolve_full_config(config_payload) |
| model = cls(config) |
| model.vision_tower = PaddleOCRVLVisionTower.from_pretrained(model_dir) |
| model.projector = PaddleOCRVLProjector.from_pretrained(model_dir) |
| return model |
|
|
| def forward( |
| self, |
| pixel_values: torch.Tensor, |
| image_grid_thw: Union[torch.Tensor, Sequence[Any]], |
| ) -> Dict[str, Any]: |
| vision_outputs = self.vision_tower( |
| pixel_values=pixel_values, |
| image_grid_thw=image_grid_thw, |
| ) |
| projector_outputs = self.projector( |
| visual_embeds=vision_outputs["visual_embeds"], |
| image_grid_thw=vision_outputs["image_grid_thw"], |
| ) |
| return { |
| **vision_outputs, |
| **projector_outputs, |
| } |
|
|
| def encode_images( |
| self, |
| images: Any, |
| image_processor: Optional[PaddleOCRVLImageProcessor] = None, |
| **processor_kwargs: Any, |
| ) -> Dict[str, Any]: |
| vision_outputs = self.vision_tower.encode_images( |
| images=images, |
| image_processor=image_processor, |
| **processor_kwargs, |
| ) |
| projector_outputs = self.projector( |
| visual_embeds=vision_outputs["visual_embeds"], |
| image_grid_thw=vision_outputs["image_grid_thw"], |
| ) |
| return {**vision_outputs, **projector_outputs} |
|
|