"""Standalone export and inference wrappers for the PaddleOCR-VL vision stack.

Provides utilities to split a full PaddleOCR-VL checkpoint into independent
vision-tower and projector artifacts (configs + safetensors), plus thin
``torch.nn.Module`` wrappers that can load either the standalone artifacts or
a full-model checkpoint directory.
"""

import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.processing_utils import BatchFeature

from .configuration_paddleocr_vl import PaddleOCRVLConfig
from .image_processing_paddleocr_vl import PaddleOCRVLImageProcessor
from .modeling_paddleocr_vl import PaddleOCRVisionModel, Projector

# File names used for the exported standalone artifacts.
VISION_TOWER_CONFIG_NAME = "vision_tower_config.json"
VISION_TOWER_WEIGHTS_NAME = "vision_tower.safetensors"
PROJECTOR_CONFIG_NAME = "projector_config.json"
PROJECTOR_WEIGHTS_NAME = "projector.safetensors"

# File names expected inside a full-model checkpoint directory.
FULL_MODEL_CONFIG_NAME = "config.json"
FULL_MODEL_WEIGHTS_NAME = "model.safetensors"

# Weight-key prefixes: in the full model the projector lives under "mlp_AR.",
# while the standalone modules expose it under "projector.".
FULL_VISUAL_PREFIX = "visual."
FULL_PROJECTOR_PREFIX = "mlp_AR."
STANDALONE_VISUAL_PREFIX = "visual."
STANDALONE_PROJECTOR_PREFIX = "projector."

# The bundled image preprocessing is image-only and asserts a temporal patch
# size of 1 (see PaddleOCRVLVisionTower.encode_images).
IMAGE_PROCESSOR_TEMPORAL_PATCH_SIZE = 1


def _read_json(path: Union[str, Path]) -> Dict[str, Any]:
    """Load a UTF-8 JSON file and return its payload."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _write_json(path: Union[str, Path], payload: Dict[str, Any]) -> None:
    """Write ``payload`` as pretty-printed UTF-8 JSON (non-ASCII preserved)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)


def _normalize_image_grid_thw(
    image_grid_thw: Union[torch.Tensor, Sequence[Any]]
) -> List[Tuple[int, int, int]]:
    """Convert per-image (t, h, w) grid specs into a list of plain int tuples.

    Accepts a 2-D tensor, a sequence of tensors, or a sequence of int
    sequences, and always returns ``[(t, h, w), ...]`` with Python ints.
    """
    if isinstance(image_grid_thw, torch.Tensor):
        return [tuple(int(v) for v in row.tolist()) for row in image_grid_thw]
    normalized: List[Tuple[int, int, int]] = []
    for item in image_grid_thw:
        if isinstance(item, torch.Tensor):
            normalized.append(tuple(int(v) for v in item.tolist()))
        else:
            normalized.append(tuple(int(v) for v in item))
    return normalized


def build_vision_encoder_export_config(
    full_config: Union[PaddleOCRVLConfig, Dict[str, Any]]
) -> Dict[str, Any]:
    """Build the export config for the combined vision tower + projector.

    The returned dict embeds the full-model config so a standalone checkpoint
    can later be re-expanded without the original directory.
    """
    if isinstance(full_config, PaddleOCRVLConfig):
        full_config_dict = full_config.to_dict()
    else:
        full_config_dict = dict(full_config)
    vision_config = dict(full_config_dict["vision_config"])
    return {
        "model_type": "paddleocr_vl_vision_encoder",
        "architectures": ["PaddleOCRVLVisionEncoder"],
        "source_model_type": full_config_dict.get("model_type", "paddleocr_vl"),
        "source_architecture": "PaddleOCRVLForConditionalGeneration",
        "text_hidden_size": full_config_dict["hidden_size"],
        "image_token_id": full_config_dict.get("image_token_id"),
        "vision_start_token_id": full_config_dict.get("vision_start_token_id"),
        "vision_end_token_id": full_config_dict.get("vision_end_token_id"),
        "torch_dtype": full_config_dict.get("torch_dtype"),
        "vision_config": vision_config,
        "projector": {
            "merge_kernel_size": [2, 2],
            "input_hidden_size": vision_config["hidden_size"],
            "output_hidden_size": full_config_dict["hidden_size"],
        },
        "required_weight_prefixes": [
            STANDALONE_VISUAL_PREFIX,
            STANDALONE_PROJECTOR_PREFIX,
        ],
        "source_weight_prefixes": {
            "visual": FULL_VISUAL_PREFIX,
            "projector": FULL_PROJECTOR_PREFIX,
        },
        "full_model_config": full_config_dict,
    }


def build_vision_tower_export_config(
    full_config: Union[PaddleOCRVLConfig, Dict[str, Any]]
) -> Dict[str, Any]:
    """Build the export config for the vision tower alone (no projector)."""
    combined = build_vision_encoder_export_config(full_config)
    return {
        "model_type": "paddleocr_vl_vision_tower",
        "architectures": ["PaddleOCRVLVisionTower"],
        "torch_dtype": combined.get("torch_dtype"),
        "vision_config": combined["vision_config"],
        "required_weight_prefixes": [STANDALONE_VISUAL_PREFIX],
        "source_weight_prefixes": {"visual": FULL_VISUAL_PREFIX},
        "full_model_config": combined["full_model_config"],
    }


def build_projector_export_config(
    full_config: Union[PaddleOCRVLConfig, Dict[str, Any]]
) -> Dict[str, Any]:
    """Build the export config for the projector alone (no vision tower)."""
    combined = build_vision_encoder_export_config(full_config)
    return {
        "model_type": "paddleocr_vl_projector",
        "architectures": ["PaddleOCRVLProjector"],
        "torch_dtype": combined.get("torch_dtype"),
        "vision_config": combined["vision_config"],
        "text_hidden_size": combined["text_hidden_size"],
        "projector": combined["projector"],
        "required_weight_prefixes": [STANDALONE_PROJECTOR_PREFIX],
        "source_weight_prefixes": {"projector": FULL_PROJECTOR_PREFIX},
        "full_model_config": combined["full_model_config"],
    }


def remap_full_model_state_dict_to_vision_encoder_parts(
    full_state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, List[str]]]:
    """Split a full-model state dict into visual and projector state dicts.

    Keys under ``visual.`` / ``mlp_AR.`` are remapped to the standalone
    prefixes; all other keys are ignored.

    Returns:
        ``(visual_state_dict, projector_state_dict, consumed)`` where
        ``consumed`` maps ``"visual"``/``"projector"`` to the sorted original
        keys that were taken from ``full_state_dict``.

    Raises:
        ValueError: if no ``visual.*`` or no ``mlp_AR.*`` weights are present.
    """
    visual_state_dict: Dict[str, torch.Tensor] = {}
    projector_state_dict: Dict[str, torch.Tensor] = {}
    consumed_visual: List[str] = []
    consumed_projector: List[str] = []
    for key, value in full_state_dict.items():
        if key.startswith(FULL_VISUAL_PREFIX):
            new_key = STANDALONE_VISUAL_PREFIX + key[len(FULL_VISUAL_PREFIX):]
            visual_state_dict[new_key] = value
            consumed_visual.append(key)
        elif key.startswith(FULL_PROJECTOR_PREFIX):
            new_key = STANDALONE_PROJECTOR_PREFIX + key[len(FULL_PROJECTOR_PREFIX):]
            projector_state_dict[new_key] = value
            consumed_projector.append(key)
    if not consumed_visual:
        raise ValueError("No visual.* weights were found in the full model state dict.")
    if not consumed_projector:
        raise ValueError("No mlp_AR.* weights were found in the full model state dict.")
    return visual_state_dict, projector_state_dict, {
        "visual": sorted(consumed_visual),
        "projector": sorted(consumed_projector),
    }


def remap_full_model_state_dict_to_vision_encoder(
    full_state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, List[str]]]:
    """Remap a full-model state dict into one combined vision-encoder dict."""
    visual_state_dict, projector_state_dict, consumed = (
        remap_full_model_state_dict_to_vision_encoder_parts(full_state_dict)
    )
    remapped = {}
    remapped.update(visual_state_dict)
    remapped.update(projector_state_dict)
    return remapped, consumed


def _load_safetensors_state_dict(path: Union[str, Path]) -> Dict[str, torch.Tensor]:
    """Load a state dict from a .safetensors file.

    Raises:
        RuntimeError: if the optional ``safetensors`` package is missing.
    """
    try:
        from safetensors.torch import load_file
    except ImportError as e:
        raise RuntimeError(
            "Loading safetensors requires the `safetensors` package to be installed."
        ) from e
    return load_file(str(path))


def _save_safetensors_state_dict(
    state_dict: Dict[str, torch.Tensor], path: Union[str, Path]
) -> None:
    """Save a state dict to a .safetensors file.

    Raises:
        RuntimeError: if the optional ``safetensors`` package is missing.
    """
    try:
        from safetensors.torch import save_file
    except ImportError as e:
        raise RuntimeError(
            "Saving safetensors requires the `safetensors` package to be installed."
        ) from e
    save_file(state_dict, str(path))


def extract_and_save_vision_encoder_artifacts(
    full_config: Union[PaddleOCRVLConfig, Dict[str, Any]],
    full_state_dict: Dict[str, torch.Tensor],
    output_dir: Union[str, Path],
) -> Dict[str, Any]:
    """Export vision-tower, projector, and combined artifacts to ``output_dir``.

    Writes separate config + weights files for the tower and the projector at
    the top level, plus a combined vision-encoder checkpoint under
    ``output_dir/combined``.

    Returns:
        A metadata dict with the paths of every written file, tensor counts,
        and the full-model keys that were consumed.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    vision_tower_config = build_vision_tower_export_config(full_config)
    projector_config = build_projector_export_config(full_config)
    visual_state_dict, projector_state_dict, consumed = (
        remap_full_model_state_dict_to_vision_encoder_parts(full_state_dict)
    )

    _save_safetensors_state_dict(
        visual_state_dict, output_dir / VISION_TOWER_WEIGHTS_NAME
    )
    _write_json(output_dir / VISION_TOWER_CONFIG_NAME, vision_tower_config)
    _save_safetensors_state_dict(
        projector_state_dict, output_dir / PROJECTOR_WEIGHTS_NAME
    )
    _write_json(output_dir / PROJECTOR_CONFIG_NAME, projector_config)

    combined_export_config = build_vision_encoder_export_config(full_config)
    combined_state_dict, _ = remap_full_model_state_dict_to_vision_encoder(
        full_state_dict
    )
    combined_dir = output_dir / "combined"
    combined_dir.mkdir(parents=True, exist_ok=True)
    _save_safetensors_state_dict(
        combined_state_dict, combined_dir / "vision_encoder.safetensors"
    )
    _write_json(combined_dir / "vision_encoder_config.json", combined_export_config)

    metadata = {
        "vision_tower_config_path": str(output_dir / VISION_TOWER_CONFIG_NAME),
        "vision_tower_weights_path": str(output_dir / VISION_TOWER_WEIGHTS_NAME),
        "projector_config_path": str(output_dir / PROJECTOR_CONFIG_NAME),
        "projector_weights_path": str(output_dir / PROJECTOR_WEIGHTS_NAME),
        "combined_config_path": str(combined_dir / "vision_encoder_config.json"),
        "combined_weights_path": str(combined_dir / "vision_encoder.safetensors"),
        "num_exported_visual_tensors": len(visual_state_dict),
        "num_exported_projector_tensors": len(projector_state_dict),
        "consumed_full_model_keys": consumed,
    }
    return metadata


class PaddleOCRVLVisionTower(torch.nn.Module):
    """Standalone wrapper around the PaddleOCR-VL vision transformer."""

    def __init__(self, config: PaddleOCRVLConfig):
        super().__init__()
        self.config = config
        self.visual = PaddleOCRVisionModel(config.vision_config)
        self.export_config = build_vision_tower_export_config(config)

    @staticmethod
    def _resolve_full_config(config_payload: Dict[str, Any]) -> PaddleOCRVLConfig:
        """Accept either a standalone tower config or a full-model config."""
        if config_payload.get("model_type") == "paddleocr_vl_vision_tower":
            config_payload = config_payload["full_model_config"]
        return PaddleOCRVLConfig(**config_payload)

    @classmethod
    def from_pretrained(cls, model_dir: Union[str, Path]) -> "PaddleOCRVLVisionTower":
        """Load from a standalone-tower directory or a full-model directory.

        Raises:
            RuntimeError: if the loaded weights do not match the module exactly.
        """
        model_dir = Path(model_dir)
        config_path = model_dir / VISION_TOWER_CONFIG_NAME
        weights_path = model_dir / VISION_TOWER_WEIGHTS_NAME
        if config_path.exists():
            config_payload = _read_json(config_path)
        else:
            config_payload = _read_json(model_dir / FULL_MODEL_CONFIG_NAME)
        model = cls(cls._resolve_full_config(config_payload))
        if weights_path.exists():
            state_dict = _load_safetensors_state_dict(weights_path)
        else:
            full_state_dict = _load_safetensors_state_dict(
                model_dir / FULL_MODEL_WEIGHTS_NAME
            )
            state_dict, _, _ = remap_full_model_state_dict_to_vision_encoder_parts(
                full_state_dict
            )
        # strict=False so we can raise our own, more informative error below
        # (with strict=True the custom message would be unreachable).
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            raise RuntimeError(
                f"Failed to load standalone vision tower weights. "
                f"Missing: {missing}, unexpected: {unexpected}"
            )
        return model

    def save_pretrained(self, output_dir: Union[str, Path]) -> None:
        """Write the tower weights and export config to ``output_dir``."""
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _save_safetensors_state_dict(
            self.state_dict(), output_dir / VISION_TOWER_WEIGHTS_NAME
        )
        _write_json(output_dir / VISION_TOWER_CONFIG_NAME, self.export_config)

    @staticmethod
    def _build_visual_inputs(
        pixel_values: torch.Tensor,
        image_grid_thw: List[Tuple[int, int, int]],
        device: torch.device,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[int, int, int]],
        torch.Tensor,
        torch.Tensor,
    ]:
        """Build the packed per-patch index tensors the vision model expects.

        For each image, position ids cycle over the spatial (h*w) extent,
        ``sample_indices`` tags every patch with its image index, and
        ``cu_seqlens`` holds cumulative patch counts (int32, leading 0).
        """
        if pixel_values.dim() == 4:
            # Promote [num_patches, C, H, W] to a single-batch 5-D tensor.
            pixel_values = pixel_values.unsqueeze(0)
        elif pixel_values.dim() != 5:
            raise ValueError(
                "pixel_values must have shape [num_patches, C, H, W] or [1, num_patches, C, H, W]."
            )

        siglip_position_ids = []
        sample_indices = []
        cu_seqlens = [0]
        for idx, thw in enumerate(image_grid_thw):
            numel = int(np.prod(thw))
            # Spatial position ids repeat per temporal slice: modulo h*w.
            image_position_ids = torch.arange(numel, device=device) % int(
                np.prod(thw[1:])
            )
            siglip_position_ids.append(image_position_ids)
            sample_indices.append(
                torch.full((numel,), idx, dtype=torch.int64, device=device)
            )
            cu_seqlens.append(cu_seqlens[-1] + numel)

        if siglip_position_ids:
            siglip_position_ids = torch.cat(siglip_position_ids, dim=0)
            sample_indices = torch.cat(sample_indices, dim=0)
        else:
            siglip_position_ids = torch.empty(0, dtype=torch.long, device=device)
            sample_indices = torch.empty(0, dtype=torch.long, device=device)
        cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
        return (
            pixel_values,
            siglip_position_ids,
            image_grid_thw,
            sample_indices,
            cu_seqlens_tensor,
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        image_grid_thw: Union[torch.Tensor, Sequence[Any]],
    ) -> Dict[str, Any]:
        """Encode preprocessed patches into per-patch visual embeddings."""
        image_grid_thw_list = _normalize_image_grid_thw(image_grid_thw)
        # Match the input dtype to the tower's parameter dtype.
        vision_dtype = next(self.visual.parameters()).dtype
        pixel_values = pixel_values.to(dtype=vision_dtype)
        device = pixel_values.device
        (
            pixel_values_5d,
            siglip_position_ids,
            image_grid_hws,
            sample_indices,
            cu_seqlens,
        ) = self._build_visual_inputs(pixel_values, image_grid_thw_list, device)
        vision_outputs: BaseModelOutputWithPooling = self.visual(
            pixel_values=pixel_values_5d,
            image_grid_thw=image_grid_hws,
            position_ids=siglip_position_ids,
            vision_return_embed_list=True,
            interpolate_pos_encoding=True,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            return_pooler_output=False,
            use_rope=True,
            window_size=-1,
        )
        return {
            "visual_embeds": vision_outputs.last_hidden_state,
            "image_grid_thw": image_grid_thw_list,
            "siglip_position_ids": siglip_position_ids,
            "sample_indices": sample_indices,
            "cu_seqlens": cu_seqlens,
        }

    def encode_images(
        self,
        images: Any,
        image_processor: Optional[PaddleOCRVLImageProcessor] = None,
        **processor_kwargs: Any,
    ) -> Dict[str, Any]:
        """Preprocess raw images and run the vision tower on them."""
        image_processor = image_processor or PaddleOCRVLImageProcessor(
            patch_size=self.config.vision_config.patch_size,
            # The current image preprocessing implementation is image-only and asserts
            # `temporal_patch_size == 1`, even though the vision model config may store 2.
            temporal_patch_size=IMAGE_PROCESSOR_TEMPORAL_PATCH_SIZE,
            merge_size=self.config.vision_config.spatial_merge_size,
        )
        encoded: BatchFeature = image_processor(
            images=images, return_tensors="pt", **processor_kwargs
        )
        return self.forward(
            pixel_values=encoded["pixel_values"],
            image_grid_thw=encoded["image_grid_thw"],
        )


class PaddleOCRVLProjector(torch.nn.Module):
    """Standalone wrapper around the PaddleOCR-VL vision-to-text projector."""

    def __init__(self, config: PaddleOCRVLConfig):
        super().__init__()
        self.config = config
        self.projector = Projector(config, config.vision_config)
        self.export_config = build_projector_export_config(config)

    @staticmethod
    def _resolve_full_config(config_payload: Dict[str, Any]) -> PaddleOCRVLConfig:
        """Accept either a standalone projector config or a full-model config."""
        if config_payload.get("model_type") == "paddleocr_vl_projector":
            config_payload = config_payload["full_model_config"]
        return PaddleOCRVLConfig(**config_payload)

    @classmethod
    def from_pretrained(cls, model_dir: Union[str, Path]) -> "PaddleOCRVLProjector":
        """Load from a standalone-projector directory or a full-model directory.

        Raises:
            RuntimeError: if the loaded weights do not match the module exactly.
        """
        model_dir = Path(model_dir)
        config_path = model_dir / PROJECTOR_CONFIG_NAME
        weights_path = model_dir / PROJECTOR_WEIGHTS_NAME
        if config_path.exists():
            config_payload = _read_json(config_path)
        else:
            config_payload = _read_json(model_dir / FULL_MODEL_CONFIG_NAME)
        model = cls(cls._resolve_full_config(config_payload))
        if weights_path.exists():
            state_dict = _load_safetensors_state_dict(weights_path)
        else:
            full_state_dict = _load_safetensors_state_dict(
                model_dir / FULL_MODEL_WEIGHTS_NAME
            )
            _, state_dict, _ = remap_full_model_state_dict_to_vision_encoder_parts(
                full_state_dict
            )
        # strict=False so we can raise our own, more informative error below
        # (with strict=True the custom message would be unreachable).
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            raise RuntimeError(
                f"Failed to load standalone projector weights. "
                f"Missing: {missing}, unexpected: {unexpected}"
            )
        return model

    def save_pretrained(self, output_dir: Union[str, Path]) -> None:
        """Write the projector weights and export config to ``output_dir``."""
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _save_safetensors_state_dict(
            self.state_dict(), output_dir / PROJECTOR_WEIGHTS_NAME
        )
        _write_json(output_dir / PROJECTOR_CONFIG_NAME, self.export_config)

    def forward(
        self,
        visual_embeds: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
        image_grid_thw: Union[torch.Tensor, Sequence[Any]],
    ) -> Dict[str, Any]:
        """Project visual embeddings into the text model's hidden space.

        Returns per-image embeddings, their concatenation (empty
        ``[0, hidden_size]`` tensor when there are no images), and the
        normalized grid list.
        """
        image_grid_thw_list = _normalize_image_grid_thw(image_grid_thw)
        image_embeds = self.projector(visual_embeds, image_grid_thw_list)
        projector_dtype = next(self.projector.parameters()).dtype
        projector_device = next(self.projector.parameters()).device
        concat_image_embeds = (
            torch.cat(image_embeds, dim=0)
            if image_embeds
            else torch.empty(
                0,
                self.config.hidden_size,
                device=projector_device,
                dtype=projector_dtype,
            )
        )
        return {
            "image_embeds": image_embeds,
            "concat_image_embeds": concat_image_embeds,
            "image_grid_thw": image_grid_thw_list,
        }


class PaddleOCRVLVisionEncoder(torch.nn.Module):
    """Vision tower + projector combined into one standalone module."""

    def __init__(self, config: PaddleOCRVLConfig):
        super().__init__()
        self.config = config
        self.vision_tower = PaddleOCRVLVisionTower(config)
        self.projector = PaddleOCRVLProjector(config)
        self.export_config = build_vision_encoder_export_config(config)

    @classmethod
    def from_pretrained(cls, model_dir: Union[str, Path]) -> "PaddleOCRVLVisionEncoder":
        """Load both sub-modules from any supported checkpoint directory.

        Raises:
            FileNotFoundError: if no recognized config file is present.
        """
        model_dir = Path(model_dir)
        config_candidates = [
            model_dir / FULL_MODEL_CONFIG_NAME,
            model_dir / VISION_TOWER_CONFIG_NAME,
            model_dir / PROJECTOR_CONFIG_NAME,
        ]
        config_path = next((path for path in config_candidates if path.exists()), None)
        if config_path is None:
            raise FileNotFoundError(
                "Could not find config.json, vision_tower_config.json, or projector_config.json."
            )
        config_payload = _read_json(config_path)
        if config_payload.get("model_type") == "paddleocr_vl_vision_tower":
            config = PaddleOCRVLVisionTower._resolve_full_config(config_payload)
        else:
            # Handles both "paddleocr_vl_projector" and full-model payloads.
            config = PaddleOCRVLProjector._resolve_full_config(config_payload)
        model = cls(config)
        # Each sub-module re-resolves its own config/weights from the same dir.
        model.vision_tower = PaddleOCRVLVisionTower.from_pretrained(model_dir)
        model.projector = PaddleOCRVLProjector.from_pretrained(model_dir)
        return model

    def forward(
        self,
        pixel_values: torch.Tensor,
        image_grid_thw: Union[torch.Tensor, Sequence[Any]],
    ) -> Dict[str, Any]:
        """Run the tower then the projector; merge both output dicts."""
        vision_outputs = self.vision_tower(
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
        )
        projector_outputs = self.projector(
            visual_embeds=vision_outputs["visual_embeds"],
            image_grid_thw=vision_outputs["image_grid_thw"],
        )
        return {
            **vision_outputs,
            **projector_outputs,
        }

    def encode_images(
        self,
        images: Any,
        image_processor: Optional[PaddleOCRVLImageProcessor] = None,
        **processor_kwargs: Any,
    ) -> Dict[str, Any]:
        """Preprocess raw images, then run the tower and the projector."""
        vision_outputs = self.vision_tower.encode_images(
            images=images,
            image_processor=image_processor,
            **processor_kwargs,
        )
        projector_outputs = self.projector(
            visual_embeds=vision_outputs["visual_embeds"],
            image_grid_thw=vision_outputs["image_grid_thw"],
        )
        return {**vision_outputs, **projector_outputs}